Skip to main content

aver/
resolver.rs

//! Compile-time variable resolution pass.
//!
//! After parsing and before interpretation, this pass walks each `FnDef` body
//! and replaces `Expr::Ident(name)` with `Expr::Resolved(slot)` for
//! variables that are local to the function (parameters + bindings).
//!
//! Global/namespace identifiers are left as `Expr::Ident` — the interpreter
//! falls back to HashMap lookup for those.
//!
//! Only top-level `FnDef` bodies are resolved. Top-level `Stmt` items (globals,
//! REPL) are not touched.
12use std::collections::HashMap;
13use std::rc::Rc;
14
15use crate::ast::*;
16
17/// Run the resolver on all top-level function definitions.
18pub fn resolve_program(items: &mut [TopLevel]) {
19    for item in items.iter_mut() {
20        if let TopLevel::FnDef(fd) = item {
21            resolve_fn(fd);
22        }
23    }
24}
25
26/// Resolve a single function definition.
27fn resolve_fn(fd: &mut FnDef) {
28    let mut local_slots: HashMap<String, u16> = HashMap::new();
29    let mut next_slot: u16 = 0;
30
31    // Params get slots 0..N-1
32    for (param_name, _) in &fd.params {
33        local_slots.insert(param_name.clone(), next_slot);
34        next_slot += 1;
35    }
36
37    // Scan body for val/var bindings to pre-allocate slots
38    match fd.body.as_ref() {
39        FnBody::Expr(expr) => {
40            collect_expr_bindings(expr, &mut local_slots, &mut next_slot);
41        }
42        FnBody::Block(stmts) => {
43            collect_binding_slots(stmts, &mut local_slots, &mut next_slot);
44        }
45    }
46
47    // Resolve expressions in the body
48    let mut body = fd.body.as_ref().clone();
49    match &mut body {
50        FnBody::Expr(expr) => {
51            resolve_expr(expr, &local_slots);
52        }
53        FnBody::Block(stmts) => {
54            resolve_stmts(stmts, &local_slots);
55        }
56    }
57    fd.body = Rc::new(body);
58
59    fd.resolution = Some(FnResolution {
60        local_count: next_slot,
61        local_slots,
62    });
63}
64
65/// Collect all binding names from a statement list and assign slots.
66/// This handles match arms recursively (pattern bindings get slots too).
67fn collect_binding_slots(
68    stmts: &[Stmt],
69    local_slots: &mut HashMap<String, u16>,
70    next_slot: &mut u16,
71) {
72    for stmt in stmts {
73        match stmt {
74            Stmt::Binding(name, _, _) => {
75                if !local_slots.contains_key(name) {
76                    local_slots.insert(name.clone(), *next_slot);
77                    *next_slot += 1;
78                }
79            }
80            Stmt::Expr(expr) => {
81                collect_expr_bindings(expr, local_slots, next_slot);
82            }
83        }
84    }
85}
86
87/// Collect pattern bindings from match expressions inside an expression tree.
88fn collect_expr_bindings(expr: &Expr, local_slots: &mut HashMap<String, u16>, next_slot: &mut u16) {
89    match expr {
90        Expr::Match { subject, arms, .. } => {
91            collect_expr_bindings(subject, local_slots, next_slot);
92            for arm in arms {
93                collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
94                collect_expr_bindings(&arm.body, local_slots, next_slot);
95            }
96        }
97        Expr::BinOp(_, left, right) => {
98            collect_expr_bindings(left, local_slots, next_slot);
99            collect_expr_bindings(right, local_slots, next_slot);
100        }
101        Expr::FnCall(func, args) => {
102            collect_expr_bindings(func, local_slots, next_slot);
103            for arg in args {
104                collect_expr_bindings(arg, local_slots, next_slot);
105            }
106        }
107        Expr::ErrorProp(inner) => {
108            collect_expr_bindings(inner, local_slots, next_slot);
109        }
110        Expr::Constructor(_, Some(inner)) => {
111            collect_expr_bindings(inner, local_slots, next_slot);
112        }
113        Expr::List(elements) => {
114            for elem in elements {
115                collect_expr_bindings(elem, local_slots, next_slot);
116            }
117        }
118        Expr::Tuple(items) => {
119            for item in items {
120                collect_expr_bindings(item, local_slots, next_slot);
121            }
122        }
123        Expr::MapLiteral(entries) => {
124            for (key, value) in entries {
125                collect_expr_bindings(key, local_slots, next_slot);
126                collect_expr_bindings(value, local_slots, next_slot);
127            }
128        }
129        Expr::InterpolatedStr(parts) => {
130            for part in parts {
131                if let StrPart::Parsed(e) = part {
132                    collect_expr_bindings(e, local_slots, next_slot);
133                }
134            }
135        }
136        Expr::RecordCreate { fields, .. } => {
137            for (_, expr) in fields {
138                collect_expr_bindings(expr, local_slots, next_slot);
139            }
140        }
141        Expr::RecordUpdate { base, updates, .. } => {
142            collect_expr_bindings(base, local_slots, next_slot);
143            for (_, expr) in updates {
144                collect_expr_bindings(expr, local_slots, next_slot);
145            }
146        }
147        Expr::Attr(obj, _) => {
148            collect_expr_bindings(obj, local_slots, next_slot);
149        }
150        Expr::TailCall(boxed) => {
151            for arg in &boxed.1 {
152                collect_expr_bindings(arg, local_slots, next_slot);
153            }
154        }
155        // Leaves — no bindings to collect
156        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
157    }
158}
159
160/// Assign slots for names introduced by a pattern.
161fn collect_pattern_bindings(
162    pattern: &Pattern,
163    local_slots: &mut HashMap<String, u16>,
164    next_slot: &mut u16,
165) {
166    match pattern {
167        Pattern::Ident(name) => {
168            if !local_slots.contains_key(name) {
169                local_slots.insert(name.clone(), *next_slot);
170                *next_slot += 1;
171            }
172        }
173        Pattern::Cons(head, tail) => {
174            for name in [head, tail] {
175                if name != "_" && !local_slots.contains_key(name) {
176                    local_slots.insert(name.clone(), *next_slot);
177                    *next_slot += 1;
178                }
179            }
180        }
181        Pattern::Constructor(_, bindings) => {
182            for name in bindings {
183                if name != "_" && !local_slots.contains_key(name) {
184                    local_slots.insert(name.clone(), *next_slot);
185                    *next_slot += 1;
186                }
187            }
188        }
189        Pattern::Tuple(items) => {
190            for item in items {
191                collect_pattern_bindings(item, local_slots, next_slot);
192            }
193        }
194        Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
195    }
196}
197
198/// Resolve `Expr::Ident` → `Expr::Resolved` for locals in an expression.
199fn resolve_expr(expr: &mut Expr, local_slots: &HashMap<String, u16>) {
200    match expr {
201        Expr::Ident(name) => {
202            if let Some(&slot) = local_slots.get(name) {
203                *expr = Expr::Resolved(slot);
204            }
205            // else: global/namespace — leave as Ident for HashMap fallback
206        }
207        Expr::Resolved(_) | Expr::Literal(_) => {}
208        Expr::Attr(obj, _) => {
209            resolve_expr(obj, local_slots);
210        }
211        Expr::FnCall(func, args) => {
212            resolve_expr(func, local_slots);
213            for arg in args {
214                resolve_expr(arg, local_slots);
215            }
216        }
217        Expr::BinOp(_, left, right) => {
218            resolve_expr(left, local_slots);
219            resolve_expr(right, local_slots);
220        }
221        Expr::Match { subject, arms, .. } => {
222            resolve_expr(subject, local_slots);
223            for arm in arms {
224                resolve_expr(&mut arm.body, local_slots);
225            }
226        }
227        Expr::Constructor(_, Some(inner)) => {
228            resolve_expr(inner, local_slots);
229        }
230        Expr::Constructor(_, None) => {}
231        Expr::ErrorProp(inner) => {
232            resolve_expr(inner, local_slots);
233        }
234        Expr::InterpolatedStr(parts) => {
235            for part in parts {
236                if let StrPart::Parsed(e) = part {
237                    resolve_expr(e, local_slots);
238                }
239            }
240        }
241        Expr::List(elements) => {
242            for elem in elements {
243                resolve_expr(elem, local_slots);
244            }
245        }
246        Expr::Tuple(items) => {
247            for item in items {
248                resolve_expr(item, local_slots);
249            }
250        }
251        Expr::MapLiteral(entries) => {
252            for (key, value) in entries {
253                resolve_expr(key, local_slots);
254                resolve_expr(value, local_slots);
255            }
256        }
257        Expr::RecordCreate { fields, .. } => {
258            for (_, expr) in fields {
259                resolve_expr(expr, local_slots);
260            }
261        }
262        Expr::RecordUpdate { base, updates, .. } => {
263            resolve_expr(base, local_slots);
264            for (_, expr) in updates {
265                resolve_expr(expr, local_slots);
266            }
267        }
268        Expr::TailCall(boxed) => {
269            for arg in &mut boxed.1 {
270                resolve_expr(arg, local_slots);
271            }
272        }
273    }
274}
275
276/// Resolve expressions inside statements.
277fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
278    for stmt in stmts {
279        match stmt {
280            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
281                resolve_expr(expr, local_slots);
282            }
283        }
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unresolved function on line 1 that returns `Int`, has no
    /// effects or description, and takes the given `Int` parameters.
    fn int_fn(name: &str, params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: params
                .iter()
                .map(|param| (param.to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    /// Shorthand for an unresolved identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Ident(name.to_string())
    }

    #[test]
    fn resolves_param_to_slot() {
        // a + b, with both operands being parameters
        let sum = Expr::BinOp(BinOp::Add, Box::new(ident("a")), Box::new(ident("b")));
        let mut fd = int_fn("add", &["a", "b"], FnBody::Expr(sum));

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.as_ref() {
            FnBody::Expr(Expr::BinOp(_, left, right)) => {
                assert_eq!(**left, Expr::Resolved(0));
                assert_eq!(**right, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        // Console(x): `Console` has no slot, `x` is a parameter
        let call = Expr::FnCall(Box::new(ident("Console")), vec![ident("x")]);
        let mut fd = int_fn("f", &["x"], FnBody::Expr(call));

        resolve_fn(&mut fd);

        match fd.body.as_ref() {
            FnBody::Expr(Expr::FnCall(func, args)) => {
                // Console stays as Ident (global)
                assert_eq!(**func, ident("Console"));
                // x is resolved to slot 0
                assert_eq!(args[0], Expr::Resolved(0));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        // val y = x + 1
        // y
        let x_plus_one = Expr::BinOp(
            BinOp::Add,
            Box::new(ident("x")),
            Box::new(Expr::Literal(Literal::Int(1))),
        );
        let block = vec![
            Stmt::Binding("y".to_string(), None, x_plus_one),
            Stmt::Expr(ident("y")),
        ];
        let mut fd = int_fn("f", &["x"], FnBody::Block(block));

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.as_ref() {
            FnBody::Block(stmts) => {
                // val y = x + 1  →  val y = Resolved(0) + 1
                match &stmts[0] {
                    Stmt::Binding(_, _, Expr::BinOp(_, left, _)) => {
                        assert_eq!(**left, Expr::Resolved(0));
                    }
                    other => panic!("unexpected stmt: {:?}", other),
                }
                // y  →  Resolved(1)
                match &stmts[1] {
                    Stmt::Expr(Expr::Resolved(1)) => {}
                    other => panic!("unexpected stmt: {:?}", other),
                }
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        // fn f(x: Int) -> Int / match x: Result.Ok(v) -> v, _ -> 0
        let ok_arm = MatchArm {
            pattern: Pattern::Constructor("Result.Ok".to_string(), vec!["v".to_string()]),
            body: Box::new(ident("v")),
        };
        let fallback_arm = MatchArm {
            pattern: Pattern::Wildcard,
            body: Box::new(Expr::Literal(Literal::Int(0))),
        };
        let matched = Expr::Match {
            subject: Box::new(ident("x")),
            arms: vec![ok_arm, fallback_arm],
            line: 1,
        };
        let mut fd = int_fn("f", &["x"], FnBody::Expr(matched));

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        // x=0, v=1
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.as_ref() {
            FnBody::Expr(Expr::Match { arms, .. }) => {
                // v in the first arm's body should be Resolved(1)
                assert_eq!(*arms[0].body, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }
}
442}