Skip to main content

aver/
resolver.rs

//! Compile-time variable resolution pass.
//!
//! After parsing and before interpretation, this pass walks each `FnDef` body
//! and replaces `Expr::Ident(name)` with `Expr::Resolved(slot)` for
//! variables that are local to the function (parameters, `val`/`var`
//! bindings, and names bound by `match` patterns).
//!
//! Global/namespace identifiers are left as `Expr::Ident` — the interpreter
//! falls back to HashMap lookup for those.
//!
//! Only top-level `FnDef` bodies are resolved. Top-level `Stmt` items (globals,
//! REPL) are not touched.
12use std::collections::HashMap;
13use std::rc::Rc;
14
15use crate::ast::*;
16
17/// Run the resolver on all top-level function definitions.
18pub fn resolve_program(items: &mut [TopLevel]) {
19    for item in items.iter_mut() {
20        if let TopLevel::FnDef(fd) = item {
21            resolve_fn(fd);
22        }
23    }
24}
25
26/// Resolve a single function definition.
27fn resolve_fn(fd: &mut FnDef) {
28    let mut local_slots: HashMap<String, u16> = HashMap::new();
29    let mut next_slot: u16 = 0;
30
31    // Params get slots 0..N-1
32    for (param_name, _) in &fd.params {
33        local_slots.insert(param_name.clone(), next_slot);
34        next_slot += 1;
35    }
36
37    // Scan body for val/var bindings to pre-allocate slots
38    collect_binding_slots(fd.body.stmts(), &mut local_slots, &mut next_slot);
39
40    // Resolve expressions in the body
41    let mut body = fd.body.as_ref().clone();
42    resolve_stmts(body.stmts_mut(), &local_slots);
43    fd.body = Rc::new(body);
44
45    fd.resolution = Some(FnResolution {
46        local_count: next_slot,
47        local_slots,
48    });
49}
50
51/// Collect all binding names from a statement list and assign slots.
52/// This handles match arms recursively (pattern bindings get slots too).
53fn collect_binding_slots(
54    stmts: &[Stmt],
55    local_slots: &mut HashMap<String, u16>,
56    next_slot: &mut u16,
57) {
58    for stmt in stmts {
59        match stmt {
60            Stmt::Binding(name, _, _) => {
61                if !local_slots.contains_key(name) {
62                    local_slots.insert(name.clone(), *next_slot);
63                    *next_slot += 1;
64                }
65            }
66            Stmt::Expr(expr) => {
67                collect_expr_bindings(expr, local_slots, next_slot);
68            }
69        }
70    }
71}
72
73/// Collect pattern bindings from match expressions inside an expression tree.
74fn collect_expr_bindings(expr: &Expr, local_slots: &mut HashMap<String, u16>, next_slot: &mut u16) {
75    match expr {
76        Expr::Match { subject, arms, .. } => {
77            collect_expr_bindings(subject, local_slots, next_slot);
78            for arm in arms {
79                collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
80                collect_expr_bindings(&arm.body, local_slots, next_slot);
81            }
82        }
83        Expr::BinOp(_, left, right) => {
84            collect_expr_bindings(left, local_slots, next_slot);
85            collect_expr_bindings(right, local_slots, next_slot);
86        }
87        Expr::FnCall(func, args) => {
88            collect_expr_bindings(func, local_slots, next_slot);
89            for arg in args {
90                collect_expr_bindings(arg, local_slots, next_slot);
91            }
92        }
93        Expr::ErrorProp(inner) => {
94            collect_expr_bindings(inner, local_slots, next_slot);
95        }
96        Expr::Constructor(_, Some(inner)) => {
97            collect_expr_bindings(inner, local_slots, next_slot);
98        }
99        Expr::List(elements) => {
100            for elem in elements {
101                collect_expr_bindings(elem, local_slots, next_slot);
102            }
103        }
104        Expr::Tuple(items) => {
105            for item in items {
106                collect_expr_bindings(item, local_slots, next_slot);
107            }
108        }
109        Expr::MapLiteral(entries) => {
110            for (key, value) in entries {
111                collect_expr_bindings(key, local_slots, next_slot);
112                collect_expr_bindings(value, local_slots, next_slot);
113            }
114        }
115        Expr::InterpolatedStr(parts) => {
116            for part in parts {
117                if let StrPart::Parsed(e) = part {
118                    collect_expr_bindings(e, local_slots, next_slot);
119                }
120            }
121        }
122        Expr::RecordCreate { fields, .. } => {
123            for (_, expr) in fields {
124                collect_expr_bindings(expr, local_slots, next_slot);
125            }
126        }
127        Expr::RecordUpdate { base, updates, .. } => {
128            collect_expr_bindings(base, local_slots, next_slot);
129            for (_, expr) in updates {
130                collect_expr_bindings(expr, local_slots, next_slot);
131            }
132        }
133        Expr::Attr(obj, _) => {
134            collect_expr_bindings(obj, local_slots, next_slot);
135        }
136        Expr::TailCall(boxed) => {
137            for arg in &boxed.1 {
138                collect_expr_bindings(arg, local_slots, next_slot);
139            }
140        }
141        // Leaves — no bindings to collect
142        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
143    }
144}
145
146/// Assign slots for names introduced by a pattern.
147fn collect_pattern_bindings(
148    pattern: &Pattern,
149    local_slots: &mut HashMap<String, u16>,
150    next_slot: &mut u16,
151) {
152    match pattern {
153        Pattern::Ident(name) => {
154            if !local_slots.contains_key(name) {
155                local_slots.insert(name.clone(), *next_slot);
156                *next_slot += 1;
157            }
158        }
159        Pattern::Cons(head, tail) => {
160            for name in [head, tail] {
161                if name != "_" && !local_slots.contains_key(name) {
162                    local_slots.insert(name.clone(), *next_slot);
163                    *next_slot += 1;
164                }
165            }
166        }
167        Pattern::Constructor(_, bindings) => {
168            for name in bindings {
169                if name != "_" && !local_slots.contains_key(name) {
170                    local_slots.insert(name.clone(), *next_slot);
171                    *next_slot += 1;
172                }
173            }
174        }
175        Pattern::Tuple(items) => {
176            for item in items {
177                collect_pattern_bindings(item, local_slots, next_slot);
178            }
179        }
180        Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
181    }
182}
183
184/// Resolve `Expr::Ident` → `Expr::Resolved` for locals in an expression.
185fn resolve_expr(expr: &mut Expr, local_slots: &HashMap<String, u16>) {
186    match expr {
187        Expr::Ident(name) => {
188            if let Some(&slot) = local_slots.get(name) {
189                *expr = Expr::Resolved(slot);
190            }
191            // else: global/namespace — leave as Ident for HashMap fallback
192        }
193        Expr::Resolved(_) | Expr::Literal(_) => {}
194        Expr::Attr(obj, _) => {
195            resolve_expr(obj, local_slots);
196        }
197        Expr::FnCall(func, args) => {
198            resolve_expr(func, local_slots);
199            for arg in args {
200                resolve_expr(arg, local_slots);
201            }
202        }
203        Expr::BinOp(_, left, right) => {
204            resolve_expr(left, local_slots);
205            resolve_expr(right, local_slots);
206        }
207        Expr::Match { subject, arms, .. } => {
208            resolve_expr(subject, local_slots);
209            for arm in arms {
210                resolve_expr(&mut arm.body, local_slots);
211            }
212        }
213        Expr::Constructor(_, Some(inner)) => {
214            resolve_expr(inner, local_slots);
215        }
216        Expr::Constructor(_, None) => {}
217        Expr::ErrorProp(inner) => {
218            resolve_expr(inner, local_slots);
219        }
220        Expr::InterpolatedStr(parts) => {
221            for part in parts {
222                if let StrPart::Parsed(e) = part {
223                    resolve_expr(e, local_slots);
224                }
225            }
226        }
227        Expr::List(elements) => {
228            for elem in elements {
229                resolve_expr(elem, local_slots);
230            }
231        }
232        Expr::Tuple(items) => {
233            for item in items {
234                resolve_expr(item, local_slots);
235            }
236        }
237        Expr::MapLiteral(entries) => {
238            for (key, value) in entries {
239                resolve_expr(key, local_slots);
240                resolve_expr(value, local_slots);
241            }
242        }
243        Expr::RecordCreate { fields, .. } => {
244            for (_, expr) in fields {
245                resolve_expr(expr, local_slots);
246            }
247        }
248        Expr::RecordUpdate { base, updates, .. } => {
249            resolve_expr(base, local_slots);
250            for (_, expr) in updates {
251                resolve_expr(expr, local_slots);
252            }
253        }
254        Expr::TailCall(boxed) => {
255            for arg in &mut boxed.1 {
256                resolve_expr(arg, local_slots);
257            }
258        }
259    }
260}
261
262/// Resolve expressions inside statements.
263fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
264    for stmt in stmts {
265        match stmt {
266            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
267                resolve_expr(expr, local_slots);
268            }
269        }
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a test function named `name`: every parameter and the return
    /// type are `Int`, declared on line 1, with no effects or description.
    fn int_fn(name: &str, params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: params
                .iter()
                .map(|p| ((*p).to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    /// Shorthand for an unresolved identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Ident(name.to_string())
    }

    #[test]
    fn resolves_param_to_slot() {
        // fn add(a: Int, b: Int) -> Int = a + b
        let sum = Expr::BinOp(BinOp::Add, Box::new(ident("a")), Box::new(ident("b")));
        let mut fd = int_fn("add", &["a", "b"], FnBody::from_expr(sum));
        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Expr::BinOp(_, lhs, rhs)) => {
                assert_eq!(**lhs, Expr::Resolved(0));
                assert_eq!(**rhs, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        // fn f(x: Int) -> Int = Console(x)  — `Console` is not a local.
        let call = Expr::FnCall(Box::new(ident("Console")), vec![ident("x")]);
        let mut fd = int_fn("f", &["x"], FnBody::from_expr(call));
        resolve_fn(&mut fd);

        match fd.body.tail_expr() {
            Some(Expr::FnCall(callee, call_args)) => {
                assert_eq!(**callee, Expr::Ident("Console".to_string()));
                assert_eq!(call_args[0], Expr::Resolved(0));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        // fn f(x: Int) -> Int { val y = x + 1; y }
        let init = Expr::BinOp(
            BinOp::Add,
            Box::new(ident("x")),
            Box::new(Expr::Literal(Literal::Int(1))),
        );
        let block = FnBody::Block(vec![
            Stmt::Binding("y".to_string(), None, init),
            Stmt::Expr(ident("y")),
        ]);
        let mut fd = int_fn("f", &["x"], block);
        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        // val y = x + 1  →  val y = Resolved(0) + 1
        match &stmts[0] {
            Stmt::Binding(_, _, Expr::BinOp(_, lhs, _)) => {
                assert_eq!(**lhs, Expr::Resolved(0));
            }
            other => panic!("unexpected stmt: {:?}", other),
        }
        // y  →  Resolved(1)
        match &stmts[1] {
            Stmt::Expr(Expr::Resolved(1)) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        // fn f(x: Int) -> Int / match x: Result.Ok(v) -> v, _ -> 0
        let ok_arm = MatchArm {
            pattern: Pattern::Constructor("Result.Ok".to_string(), vec!["v".to_string()]),
            body: Box::new(ident("v")),
        };
        let fallback_arm = MatchArm {
            pattern: Pattern::Wildcard,
            body: Box::new(Expr::Literal(Literal::Int(0))),
        };
        let body = FnBody::from_expr(Expr::Match {
            subject: Box::new(ident("x")),
            arms: vec![ok_arm, fallback_arm],
            line: 1,
        });
        let mut fd = int_fn("f", &["x"], body);
        resolve_fn(&mut fd);

        // Slots: x = 0 (parameter), v = 1 (pattern binding).
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.tail_expr() {
            Some(Expr::Match { arms, .. }) => {
                assert_eq!(*arms[0].body, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }
}