Skip to main content

aver/
resolver.rs

//! Compile-time variable resolution pass.
//!
//! After parsing and before interpretation, this pass walks each `FnDef` body
//! and replaces `Expr::Ident(name)` with `Expr::Resolved(slot)` for
//! variables that are local to the function (parameters + bindings).
//!
//! Global/namespace identifiers are left as `Expr::Ident` — the interpreter
//! falls back to HashMap lookup for those.
//!
//! Only top-level `FnDef` bodies are resolved. Top-level `Stmt` items (globals,
//! REPL) are not touched.
12use std::collections::HashMap;
13use std::sync::Arc as Rc;
14
15use crate::ast::*;
16
17/// Run the resolver on all top-level function definitions.
18pub fn resolve_program(items: &mut [TopLevel]) {
19    for item in items.iter_mut() {
20        if let TopLevel::FnDef(fd) = item {
21            resolve_fn(fd);
22        }
23    }
24}
25
26/// Resolve a single function definition.
27fn resolve_fn(fd: &mut FnDef) {
28    let mut local_slots: HashMap<String, u16> = HashMap::new();
29    let mut next_slot: u16 = 0;
30
31    // Params get slots 0..N-1
32    for (param_name, _) in &fd.params {
33        local_slots.insert(param_name.clone(), next_slot);
34        next_slot += 1;
35    }
36
37    // Scan body for val/var bindings to pre-allocate slots
38    collect_binding_slots(fd.body.stmts(), &mut local_slots, &mut next_slot);
39
40    // Resolve expressions in the body
41    let mut body = fd.body.as_ref().clone();
42    resolve_stmts(body.stmts_mut(), &local_slots);
43    fd.body = Rc::new(body);
44
45    fd.resolution = Some(FnResolution {
46        local_count: next_slot,
47        local_slots: Rc::new(local_slots),
48    });
49}
50
51/// Collect all binding names from a statement list and assign slots.
52/// This handles match arms recursively (pattern bindings get slots too).
53fn collect_binding_slots(
54    stmts: &[Stmt],
55    local_slots: &mut HashMap<String, u16>,
56    next_slot: &mut u16,
57) {
58    for stmt in stmts {
59        match stmt {
60            Stmt::Binding(name, _, _) => {
61                if !local_slots.contains_key(name) {
62                    local_slots.insert(name.clone(), *next_slot);
63                    *next_slot += 1;
64                }
65            }
66            Stmt::Expr(expr) => {
67                collect_expr_bindings(expr, local_slots, next_slot);
68            }
69        }
70    }
71}
72
73/// Collect pattern bindings from match expressions inside an expression tree.
74fn collect_expr_bindings(
75    expr: &Spanned<Expr>,
76    local_slots: &mut HashMap<String, u16>,
77    next_slot: &mut u16,
78) {
79    match &expr.node {
80        Expr::Match { subject, arms } => {
81            collect_expr_bindings(subject, local_slots, next_slot);
82            for arm in arms {
83                collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
84                collect_expr_bindings(&arm.body, local_slots, next_slot);
85            }
86        }
87        Expr::BinOp(_, left, right) => {
88            collect_expr_bindings(left, local_slots, next_slot);
89            collect_expr_bindings(right, local_slots, next_slot);
90        }
91        Expr::FnCall(func, args) => {
92            collect_expr_bindings(func, local_slots, next_slot);
93            for arg in args {
94                collect_expr_bindings(arg, local_slots, next_slot);
95            }
96        }
97        Expr::ErrorProp(inner) => {
98            collect_expr_bindings(inner, local_slots, next_slot);
99        }
100        Expr::Constructor(_, Some(inner)) => {
101            collect_expr_bindings(inner, local_slots, next_slot);
102        }
103        Expr::List(elements) => {
104            for elem in elements {
105                collect_expr_bindings(elem, local_slots, next_slot);
106            }
107        }
108        Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
109            for item in items {
110                collect_expr_bindings(item, local_slots, next_slot);
111            }
112        }
113        Expr::MapLiteral(entries) => {
114            for (key, value) in entries {
115                collect_expr_bindings(key, local_slots, next_slot);
116                collect_expr_bindings(value, local_slots, next_slot);
117            }
118        }
119        Expr::InterpolatedStr(parts) => {
120            for part in parts {
121                if let StrPart::Parsed(e) = part {
122                    collect_expr_bindings(e, local_slots, next_slot);
123                }
124            }
125        }
126        Expr::RecordCreate { fields, .. } => {
127            for (_, expr) in fields {
128                collect_expr_bindings(expr, local_slots, next_slot);
129            }
130        }
131        Expr::RecordUpdate { base, updates, .. } => {
132            collect_expr_bindings(base, local_slots, next_slot);
133            for (_, expr) in updates {
134                collect_expr_bindings(expr, local_slots, next_slot);
135            }
136        }
137        Expr::Attr(obj, _) => {
138            collect_expr_bindings(obj, local_slots, next_slot);
139        }
140        Expr::TailCall(boxed) => {
141            for arg in &boxed.1 {
142                collect_expr_bindings(arg, local_slots, next_slot);
143            }
144        }
145        // Leaves — no bindings to collect
146        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
147    }
148}
149
150/// Assign slots for names introduced by a pattern.
151fn collect_pattern_bindings(
152    pattern: &Pattern,
153    local_slots: &mut HashMap<String, u16>,
154    next_slot: &mut u16,
155) {
156    match pattern {
157        Pattern::Ident(name) => {
158            if !local_slots.contains_key(name) {
159                local_slots.insert(name.clone(), *next_slot);
160                *next_slot += 1;
161            }
162        }
163        Pattern::Cons(head, tail) => {
164            for name in [head, tail] {
165                if name != "_" && !local_slots.contains_key(name) {
166                    local_slots.insert(name.clone(), *next_slot);
167                    *next_slot += 1;
168                }
169            }
170        }
171        Pattern::Constructor(_, bindings) => {
172            for name in bindings {
173                if name != "_" && !local_slots.contains_key(name) {
174                    local_slots.insert(name.clone(), *next_slot);
175                    *next_slot += 1;
176                }
177            }
178        }
179        Pattern::Tuple(items) => {
180            for item in items {
181                collect_pattern_bindings(item, local_slots, next_slot);
182            }
183        }
184        Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
185    }
186}
187
188/// Resolve `Expr::Ident` → `Expr::Resolved` for locals in an expression.
189fn resolve_expr(expr: &mut Spanned<Expr>, local_slots: &HashMap<String, u16>) {
190    match &mut expr.node {
191        Expr::Ident(name) => {
192            if let Some(&slot) = local_slots.get(name) {
193                expr.node = Expr::Resolved(slot);
194            }
195            // else: global/namespace — leave as Ident for HashMap fallback
196        }
197        Expr::Resolved(_) | Expr::Literal(_) => {}
198        Expr::Attr(obj, _) => {
199            resolve_expr(obj, local_slots);
200        }
201        Expr::FnCall(func, args) => {
202            resolve_expr(func, local_slots);
203            for arg in args {
204                resolve_expr(arg, local_slots);
205            }
206        }
207        Expr::BinOp(_, left, right) => {
208            resolve_expr(left, local_slots);
209            resolve_expr(right, local_slots);
210        }
211        Expr::Match { subject, arms } => {
212            resolve_expr(subject, local_slots);
213            for arm in arms {
214                resolve_expr(&mut arm.body, local_slots);
215            }
216        }
217        Expr::Constructor(_, Some(inner)) => {
218            resolve_expr(inner, local_slots);
219        }
220        Expr::Constructor(_, None) => {}
221        Expr::ErrorProp(inner) => {
222            resolve_expr(inner, local_slots);
223        }
224        Expr::InterpolatedStr(parts) => {
225            for part in parts {
226                if let StrPart::Parsed(e) = part {
227                    resolve_expr(e, local_slots);
228                }
229            }
230        }
231        Expr::List(elements) => {
232            for elem in elements {
233                resolve_expr(elem, local_slots);
234            }
235        }
236        Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
237            for item in items {
238                resolve_expr(item, local_slots);
239            }
240        }
241        Expr::MapLiteral(entries) => {
242            for (key, value) in entries {
243                resolve_expr(key, local_slots);
244                resolve_expr(value, local_slots);
245            }
246        }
247        Expr::RecordCreate { fields, .. } => {
248            for (_, expr) in fields {
249                resolve_expr(expr, local_slots);
250            }
251        }
252        Expr::RecordUpdate { base, updates, .. } => {
253            resolve_expr(base, local_slots);
254            for (_, expr) in updates {
255                resolve_expr(expr, local_slots);
256            }
257        }
258        Expr::TailCall(boxed) => {
259            for arg in &mut boxed.1 {
260                resolve_expr(arg, local_slots);
261            }
262        }
263    }
264}
265
266/// Resolve expressions inside statements.
267fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
268    for stmt in stmts {
269        match stmt {
270            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
271                resolve_expr(expr, local_slots);
272            }
273        }
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unresolved `FnDef` on line 1 whose parameters are all typed
    /// `Int`, returning `Int`, with no effects and no description.
    fn int_fn(name: &str, params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: params
                .iter()
                .map(|p| ((*p).to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    /// Shorthand for an unresolved identifier expression.
    fn ident(name: &str) -> Spanned<Expr> {
        Spanned::bare(Expr::Ident(name.to_string()))
    }

    #[test]
    fn resolves_param_to_slot() {
        // fn add(a: Int, b: Int) -> Int = a + b
        let mut fd = int_fn(
            "add",
            &["a", "b"],
            FnBody::from_expr(Spanned::bare(Expr::BinOp(
                BinOp::Add,
                Box::new(ident("a")),
                Box::new(ident("b")),
            ))),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::BinOp(_, left, right),
                ..
            }) => {
                assert_eq!(left.node, Expr::Resolved(0));
                assert_eq!(right.node, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        // fn f(x: Int) -> Int = Console(x); `Console` is not a local.
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::from_expr(Spanned::bare(Expr::FnCall(
                Box::new(ident("Console")),
                vec![ident("x")],
            ))),
        );
        resolve_fn(&mut fd);
        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::FnCall(func, args),
                ..
            }) => {
                assert_eq!(func.node, Expr::Ident("Console".to_string()));
                assert_eq!(args[0].node, Expr::Resolved(0));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        // fn f(x: Int) -> Int { val y = x + 1; y }
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::Block(vec![
                Stmt::Binding(
                    "y".to_string(),
                    None,
                    Spanned::bare(Expr::BinOp(
                        BinOp::Add,
                        Box::new(ident("x")),
                        Box::new(Spanned::bare(Expr::Literal(Literal::Int(1)))),
                    )),
                ),
                Stmt::Expr(ident("y")),
            ]),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        // val y = x + 1  →  val y = Resolved(0) + 1
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::BinOp(_, left, _),
                    ..
                },
            ) => {
                assert_eq!(left.node, Expr::Resolved(0));
            }
            other => panic!("unexpected stmt: {:?}", other),
        }
        // y  →  Resolved(1)
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved(1),
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        // fn f(x: Int) -> Int / match x: Result.Ok(v) -> v, _ -> 0
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::from_expr(Spanned::new(
                Expr::Match {
                    subject: Box::new(ident("x")),
                    arms: vec![
                        MatchArm {
                            pattern: Pattern::Constructor(
                                "Result.Ok".to_string(),
                                vec!["v".to_string()],
                            ),
                            body: Box::new(ident("v")),
                        },
                        MatchArm {
                            pattern: Pattern::Wildcard,
                            body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                        },
                    ],
                },
                1,
            )),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        // x=0 (param), v=1 (pattern binding); the wildcard arm binds nothing.
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["v"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => {
                assert_eq!(arms[0].body.node, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }
}