Skip to main content

aver/
resolver.rs

//! Compile-time variable resolution pass.
//!
//! After parsing and before interpretation, this pass walks each `FnDef` body
//! and replaces `Expr::Ident(name)` with `Expr::Resolved(slot)` for
//! variables that are local to the function (parameters, `val`/`var`
//! bindings, and match-pattern bindings).
//!
//! Global/namespace identifiers are left as `Expr::Ident` — the VM
//! falls back to HashMap lookup for those.
//!
//! Only top-level `FnDef` bodies are resolved. Top-level `Stmt` items (globals,
//! REPL) are not touched.
12use std::collections::HashMap;
13use std::sync::Arc as Rc;
14
15use crate::ast::*;
16
17/// Run the resolver on all top-level function definitions.
18pub fn resolve_program(items: &mut [TopLevel]) {
19    for item in items.iter_mut() {
20        if let TopLevel::FnDef(fd) = item {
21            resolve_fn(fd);
22        }
23    }
24}
25
26/// Resolve a single function definition.
27fn resolve_fn(fd: &mut FnDef) {
28    let mut local_slots: HashMap<String, u16> = HashMap::new();
29    let mut next_slot: u16 = 0;
30
31    // Params get slots 0..N-1
32    for (param_name, _) in &fd.params {
33        local_slots.insert(param_name.clone(), next_slot);
34        next_slot += 1;
35    }
36
37    // Scan body for val/var bindings to pre-allocate slots
38    collect_binding_slots(fd.body.stmts(), &mut local_slots, &mut next_slot);
39
40    // Resolve expressions in the body
41    let mut body = fd.body.as_ref().clone();
42    resolve_stmts(body.stmts_mut(), &local_slots);
43    fd.body = Rc::new(body);
44
45    fd.resolution = Some(FnResolution {
46        local_count: next_slot,
47        local_slots: Rc::new(local_slots),
48    });
49}
50
51/// Collect all binding names from a statement list and assign slots.
52/// This handles match arms recursively (pattern bindings get slots too).
53fn collect_binding_slots(
54    stmts: &[Stmt],
55    local_slots: &mut HashMap<String, u16>,
56    next_slot: &mut u16,
57) {
58    for stmt in stmts {
59        match stmt {
60            Stmt::Binding(name, _, expr) => {
61                if !local_slots.contains_key(name) {
62                    local_slots.insert(name.clone(), *next_slot);
63                    *next_slot += 1;
64                }
65                collect_expr_bindings(expr, local_slots, next_slot);
66            }
67            Stmt::Expr(expr) => {
68                collect_expr_bindings(expr, local_slots, next_slot);
69            }
70        }
71    }
72}
73
74/// Collect pattern bindings from match expressions inside an expression tree.
75fn collect_expr_bindings(
76    expr: &Spanned<Expr>,
77    local_slots: &mut HashMap<String, u16>,
78    next_slot: &mut u16,
79) {
80    match &expr.node {
81        Expr::Match { subject, arms } => {
82            collect_expr_bindings(subject, local_slots, next_slot);
83            for arm in arms {
84                collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
85                collect_expr_bindings(&arm.body, local_slots, next_slot);
86            }
87        }
88        Expr::BinOp(_, left, right) => {
89            collect_expr_bindings(left, local_slots, next_slot);
90            collect_expr_bindings(right, local_slots, next_slot);
91        }
92        Expr::FnCall(func, args) => {
93            collect_expr_bindings(func, local_slots, next_slot);
94            for arg in args {
95                collect_expr_bindings(arg, local_slots, next_slot);
96            }
97        }
98        Expr::ErrorProp(inner) => {
99            collect_expr_bindings(inner, local_slots, next_slot);
100        }
101        Expr::Constructor(_, Some(inner)) => {
102            collect_expr_bindings(inner, local_slots, next_slot);
103        }
104        Expr::List(elements) => {
105            for elem in elements {
106                collect_expr_bindings(elem, local_slots, next_slot);
107            }
108        }
109        Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
110            for item in items {
111                collect_expr_bindings(item, local_slots, next_slot);
112            }
113        }
114        Expr::MapLiteral(entries) => {
115            for (key, value) in entries {
116                collect_expr_bindings(key, local_slots, next_slot);
117                collect_expr_bindings(value, local_slots, next_slot);
118            }
119        }
120        Expr::InterpolatedStr(parts) => {
121            for part in parts {
122                if let StrPart::Parsed(e) = part {
123                    collect_expr_bindings(e, local_slots, next_slot);
124                }
125            }
126        }
127        Expr::RecordCreate { fields, .. } => {
128            for (_, expr) in fields {
129                collect_expr_bindings(expr, local_slots, next_slot);
130            }
131        }
132        Expr::RecordUpdate { base, updates, .. } => {
133            collect_expr_bindings(base, local_slots, next_slot);
134            for (_, expr) in updates {
135                collect_expr_bindings(expr, local_slots, next_slot);
136            }
137        }
138        Expr::Attr(obj, _) => {
139            collect_expr_bindings(obj, local_slots, next_slot);
140        }
141        Expr::TailCall(boxed) => {
142            for arg in &boxed.1 {
143                collect_expr_bindings(arg, local_slots, next_slot);
144            }
145        }
146        // Leaves — no bindings to collect
147        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
148    }
149}
150
151/// Assign slots for names introduced by a pattern.
152fn collect_pattern_bindings(
153    pattern: &Pattern,
154    local_slots: &mut HashMap<String, u16>,
155    next_slot: &mut u16,
156) {
157    match pattern {
158        Pattern::Ident(name) => {
159            if !local_slots.contains_key(name) {
160                local_slots.insert(name.clone(), *next_slot);
161                *next_slot += 1;
162            }
163        }
164        Pattern::Cons(head, tail) => {
165            for name in [head, tail] {
166                if name != "_" && !local_slots.contains_key(name) {
167                    local_slots.insert(name.clone(), *next_slot);
168                    *next_slot += 1;
169                }
170            }
171        }
172        Pattern::Constructor(_, bindings) => {
173            for name in bindings {
174                if name != "_" && !local_slots.contains_key(name) {
175                    local_slots.insert(name.clone(), *next_slot);
176                    *next_slot += 1;
177                }
178            }
179        }
180        Pattern::Tuple(items) => {
181            for item in items {
182                collect_pattern_bindings(item, local_slots, next_slot);
183            }
184        }
185        Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
186    }
187}
188
189/// Resolve `Expr::Ident` → `Expr::Resolved` for locals in an expression.
190fn resolve_expr(expr: &mut Spanned<Expr>, local_slots: &HashMap<String, u16>) {
191    match &mut expr.node {
192        Expr::Ident(name) => {
193            if let Some(&slot) = local_slots.get(name) {
194                expr.node = Expr::Resolved(slot);
195            }
196            // else: global/namespace — leave as Ident for HashMap fallback
197        }
198        Expr::Resolved(_) | Expr::Literal(_) => {}
199        Expr::Attr(obj, _) => {
200            resolve_expr(obj, local_slots);
201        }
202        Expr::FnCall(func, args) => {
203            resolve_expr(func, local_slots);
204            for arg in args {
205                resolve_expr(arg, local_slots);
206            }
207        }
208        Expr::BinOp(_, left, right) => {
209            resolve_expr(left, local_slots);
210            resolve_expr(right, local_slots);
211        }
212        Expr::Match { subject, arms } => {
213            resolve_expr(subject, local_slots);
214            for arm in arms {
215                resolve_expr(&mut arm.body, local_slots);
216            }
217        }
218        Expr::Constructor(_, Some(inner)) => {
219            resolve_expr(inner, local_slots);
220        }
221        Expr::Constructor(_, None) => {}
222        Expr::ErrorProp(inner) => {
223            resolve_expr(inner, local_slots);
224        }
225        Expr::InterpolatedStr(parts) => {
226            for part in parts {
227                if let StrPart::Parsed(e) = part {
228                    resolve_expr(e, local_slots);
229                }
230            }
231        }
232        Expr::List(elements) => {
233            for elem in elements {
234                resolve_expr(elem, local_slots);
235            }
236        }
237        Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
238            for item in items {
239                resolve_expr(item, local_slots);
240            }
241        }
242        Expr::MapLiteral(entries) => {
243            for (key, value) in entries {
244                resolve_expr(key, local_slots);
245                resolve_expr(value, local_slots);
246            }
247        }
248        Expr::RecordCreate { fields, .. } => {
249            for (_, expr) in fields {
250                resolve_expr(expr, local_slots);
251            }
252        }
253        Expr::RecordUpdate { base, updates, .. } => {
254            resolve_expr(base, local_slots);
255            for (_, expr) in updates {
256                resolve_expr(expr, local_slots);
257            }
258        }
259        Expr::TailCall(boxed) => {
260            for arg in &mut boxed.1 {
261                resolve_expr(arg, local_slots);
262            }
263        }
264    }
265}
266
267/// Resolve expressions inside statements.
268fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
269    for stmt in stmts {
270        match stmt {
271            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
272                resolve_expr(expr, local_slots);
273            }
274        }
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolves_param_to_slot() {
        let mut fd = FnDef {
            name: "add".to_string(),
            line: 1,
            params: vec![
                ("a".to_string(), "Int".to_string()),
                ("b".to_string(), "Int".to_string()),
            ],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::from_expr(Spanned::bare(Expr::BinOp(
                BinOp::Add,
                Box::new(Spanned::bare(Expr::Ident("a".to_string()))),
                Box::new(Spanned::bare(Expr::Ident("b".to_string()))),
            )))),
            resolution: None,
        };
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::BinOp(_, left, right),
                ..
            }) => {
                assert_eq!(left.node, Expr::Resolved(0));
                assert_eq!(right.node, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("x".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::from_expr(Spanned::bare(Expr::FnCall(
                Box::new(Spanned::bare(Expr::Ident("Console".to_string()))),
                vec![Spanned::bare(Expr::Ident("x".to_string()))],
            )))),
            resolution: None,
        };
        resolve_fn(&mut fd);
        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::FnCall(func, args),
                ..
            }) => {
                assert_eq!(func.node, Expr::Ident("Console".to_string()));
                assert_eq!(args[0].node, Expr::Resolved(0));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("x".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::Block(vec![
                Stmt::Binding(
                    "y".to_string(),
                    None,
                    Spanned::bare(Expr::BinOp(
                        BinOp::Add,
                        Box::new(Spanned::bare(Expr::Ident("x".to_string()))),
                        Box::new(Spanned::bare(Expr::Literal(Literal::Int(1)))),
                    )),
                ),
                Stmt::Expr(Spanned::bare(Expr::Ident("y".to_string()))),
            ])),
            resolution: None,
        };
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        // val y = x + 1  →  val y = Resolved(0) + 1
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::BinOp(_, left, _),
                    ..
                },
            ) => {
                assert_eq!(left.node, Expr::Resolved(0));
            }
            other => panic!("unexpected stmt: {:?}", other),
        }
        // y  →  Resolved(1)
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved(1),
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        // fn f(x: Int) -> Int / match x: Result.Ok(v) -> v, _ -> 0
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("x".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::from_expr(Spanned::new(
                Expr::Match {
                    subject: Box::new(Spanned::bare(Expr::Ident("x".to_string()))),
                    arms: vec![
                        MatchArm {
                            pattern: Pattern::Constructor(
                                "Result.Ok".to_string(),
                                vec!["v".to_string()],
                            ),
                            body: Box::new(Spanned::bare(Expr::Ident("v".to_string()))),
                        },
                        MatchArm {
                            pattern: Pattern::Wildcard,
                            body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                        },
                    ],
                },
                1,
            ))),
            resolution: None,
        };
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        // x=0, v=1
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => {
                assert_eq!(arms[0].body.node, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings_inside_binding_initializer() {
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("x".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::Block(vec![
                Stmt::Binding(
                    "result".to_string(),
                    None,
                    Spanned::bare(Expr::Match {
                        subject: Box::new(Spanned::bare(Expr::Ident("x".to_string()))),
                        arms: vec![
                            MatchArm {
                                pattern: Pattern::Constructor(
                                    "Option.Some".to_string(),
                                    vec!["v".to_string()],
                                ),
                                body: Box::new(Spanned::bare(Expr::Ident("v".to_string()))),
                            },
                            MatchArm {
                                pattern: Pattern::Wildcard,
                                body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                            },
                        ],
                    }),
                ),
                Stmt::Expr(Spanned::bare(Expr::Ident("result".to_string()))),
            ])),
            resolution: None,
        };

        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["result"], 1);
        assert_eq!(res.local_slots["v"], 2);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::Match { arms, .. },
                    ..
                },
            ) => {
                assert_eq!(arms[0].body.node, Expr::Resolved(2));
            }
            other => panic!("unexpected stmt: {:?}", other),
        }

        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved(1),
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn pattern_placeholders_get_no_slot_and_repeated_names_share_one() {
        // match xs: Cons(h, _) -> h, Pair(_, h) -> h, _ -> 0
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("xs".to_string(), "List<Int>".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::from_expr(Spanned::bare(Expr::Match {
                subject: Box::new(Spanned::bare(Expr::Ident("xs".to_string()))),
                arms: vec![
                    MatchArm {
                        pattern: Pattern::Cons("h".to_string(), "_".to_string()),
                        body: Box::new(Spanned::bare(Expr::Ident("h".to_string()))),
                    },
                    MatchArm {
                        pattern: Pattern::Constructor(
                            "Pair".to_string(),
                            vec!["_".to_string(), "h".to_string()],
                        ),
                        body: Box::new(Spanned::bare(Expr::Ident("h".to_string()))),
                    },
                    MatchArm {
                        pattern: Pattern::Wildcard,
                        body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                    },
                ],
            }))),
            resolution: None,
        };
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        // xs=0, h=1 (shared across both arms); `_` never gets a slot.
        assert_eq!(res.local_slots["xs"], 0);
        assert_eq!(res.local_slots["h"], 1);
        assert!(!res.local_slots.contains_key("_"));
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { subject, arms },
                ..
            }) => {
                assert_eq!(subject.node, Expr::Resolved(0));
                assert_eq!(arms[0].body.node, Expr::Resolved(1));
                assert_eq!(arms[1].body.node, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }
}