Skip to main content

aver/
resolver.rs

/// Compile-time variable resolution pass.
///
/// After parsing and before interpretation, this pass walks each `FnDef` body
/// and replaces `Expr::Ident(name)` with `Expr::Resolved(slot)` for
/// variables that are local to the function (parameters + bindings).
///
/// Global/namespace identifiers are left as `Expr::Ident` — the interpreter
/// falls back to HashMap lookup for those.
///
/// Only top-level `FnDef` bodies are resolved. Top-level `Stmt` items (globals,
/// REPL) are not touched.
12use std::collections::HashMap;
13use std::rc::Rc;
14
15use crate::ast::*;
16
17/// Run the resolver on all top-level function definitions.
18pub fn resolve_program(items: &mut Vec<TopLevel>) {
19    for item in items.iter_mut() {
20        if let TopLevel::FnDef(fd) = item {
21            resolve_fn(fd);
22        }
23    }
24}
25
26/// Resolve a single function definition.
27fn resolve_fn(fd: &mut FnDef) {
28    let mut local_slots: HashMap<String, u16> = HashMap::new();
29    let mut next_slot: u16 = 0;
30
31    // Params get slots 0..N-1
32    for (param_name, _) in &fd.params {
33        local_slots.insert(param_name.clone(), next_slot);
34        next_slot += 1;
35    }
36
37    // Scan body for val/var bindings to pre-allocate slots
38    match fd.body.as_ref() {
39        FnBody::Expr(expr) => {
40            collect_expr_bindings(expr, &mut local_slots, &mut next_slot);
41        }
42        FnBody::Block(stmts) => {
43            collect_binding_slots(stmts, &mut local_slots, &mut next_slot);
44        }
45    }
46
47    // Resolve expressions in the body
48    let mut body = fd.body.as_ref().clone();
49    match &mut body {
50        FnBody::Expr(expr) => {
51            resolve_expr(expr, &local_slots);
52        }
53        FnBody::Block(stmts) => {
54            resolve_stmts(stmts, &local_slots);
55        }
56    }
57    fd.body = Rc::new(body);
58
59    fd.resolution = Some(FnResolution {
60        local_count: next_slot,
61        local_slots,
62    });
63}
64
65/// Collect all binding names from a statement list and assign slots.
66/// This handles match arms recursively (pattern bindings get slots too).
67fn collect_binding_slots(
68    stmts: &[Stmt],
69    local_slots: &mut HashMap<String, u16>,
70    next_slot: &mut u16,
71) {
72    for stmt in stmts {
73        match stmt {
74            Stmt::Binding(name, _, _) => {
75                if !local_slots.contains_key(name) {
76                    local_slots.insert(name.clone(), *next_slot);
77                    *next_slot += 1;
78                }
79            }
80            Stmt::Expr(expr) => {
81                collect_expr_bindings(expr, local_slots, next_slot);
82            }
83        }
84    }
85}
86
87/// Collect pattern bindings from match expressions inside an expression tree.
88fn collect_expr_bindings(expr: &Expr, local_slots: &mut HashMap<String, u16>, next_slot: &mut u16) {
89    match expr {
90        Expr::Match { subject, arms, .. } => {
91            collect_expr_bindings(subject, local_slots, next_slot);
92            for arm in arms {
93                collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
94                collect_expr_bindings(&arm.body, local_slots, next_slot);
95            }
96        }
97        Expr::BinOp(_, left, right) => {
98            collect_expr_bindings(left, local_slots, next_slot);
99            collect_expr_bindings(right, local_slots, next_slot);
100        }
101        Expr::FnCall(func, args) => {
102            collect_expr_bindings(func, local_slots, next_slot);
103            for arg in args {
104                collect_expr_bindings(arg, local_slots, next_slot);
105            }
106        }
107        Expr::Pipe(left, right) => {
108            collect_expr_bindings(left, local_slots, next_slot);
109            collect_expr_bindings(right, local_slots, next_slot);
110        }
111        Expr::ErrorProp(inner) => {
112            collect_expr_bindings(inner, local_slots, next_slot);
113        }
114        Expr::Constructor(_, Some(inner)) => {
115            collect_expr_bindings(inner, local_slots, next_slot);
116        }
117        Expr::List(elements) => {
118            for elem in elements {
119                collect_expr_bindings(elem, local_slots, next_slot);
120            }
121        }
122        Expr::Tuple(items) => {
123            for item in items {
124                collect_expr_bindings(item, local_slots, next_slot);
125            }
126        }
127        Expr::MapLiteral(entries) => {
128            for (key, value) in entries {
129                collect_expr_bindings(key, local_slots, next_slot);
130                collect_expr_bindings(value, local_slots, next_slot);
131            }
132        }
133        Expr::InterpolatedStr(parts) => {
134            for part in parts {
135                if let StrPart::Parsed(e) = part {
136                    collect_expr_bindings(e, local_slots, next_slot);
137                }
138            }
139        }
140        Expr::RecordCreate { fields, .. } => {
141            for (_, expr) in fields {
142                collect_expr_bindings(expr, local_slots, next_slot);
143            }
144        }
145        Expr::RecordUpdate { base, updates, .. } => {
146            collect_expr_bindings(base, local_slots, next_slot);
147            for (_, expr) in updates {
148                collect_expr_bindings(expr, local_slots, next_slot);
149            }
150        }
151        Expr::Attr(obj, _) => {
152            collect_expr_bindings(obj, local_slots, next_slot);
153        }
154        Expr::TailCall(boxed) => {
155            for arg in &boxed.1 {
156                collect_expr_bindings(arg, local_slots, next_slot);
157            }
158        }
159        // Leaves — no bindings to collect
160        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
161    }
162}
163
164/// Assign slots for names introduced by a pattern.
165fn collect_pattern_bindings(
166    pattern: &Pattern,
167    local_slots: &mut HashMap<String, u16>,
168    next_slot: &mut u16,
169) {
170    match pattern {
171        Pattern::Ident(name) => {
172            if !local_slots.contains_key(name) {
173                local_slots.insert(name.clone(), *next_slot);
174                *next_slot += 1;
175            }
176        }
177        Pattern::Cons(head, tail) => {
178            for name in [head, tail] {
179                if name != "_" && !local_slots.contains_key(name) {
180                    local_slots.insert(name.clone(), *next_slot);
181                    *next_slot += 1;
182                }
183            }
184        }
185        Pattern::Constructor(_, bindings) => {
186            for name in bindings {
187                if name != "_" && !local_slots.contains_key(name) {
188                    local_slots.insert(name.clone(), *next_slot);
189                    *next_slot += 1;
190                }
191            }
192        }
193        Pattern::Tuple(items) => {
194            for item in items {
195                collect_pattern_bindings(item, local_slots, next_slot);
196            }
197        }
198        Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
199    }
200}
201
202/// Resolve `Expr::Ident` → `Expr::Resolved` for locals in an expression.
203fn resolve_expr(expr: &mut Expr, local_slots: &HashMap<String, u16>) {
204    match expr {
205        Expr::Ident(name) => {
206            if let Some(&slot) = local_slots.get(name) {
207                *expr = Expr::Resolved(slot);
208            }
209            // else: global/namespace — leave as Ident for HashMap fallback
210        }
211        Expr::Resolved(_) | Expr::Literal(_) => {}
212        Expr::Attr(obj, _) => {
213            resolve_expr(obj, local_slots);
214        }
215        Expr::FnCall(func, args) => {
216            resolve_expr(func, local_slots);
217            for arg in args {
218                resolve_expr(arg, local_slots);
219            }
220        }
221        Expr::BinOp(_, left, right) => {
222            resolve_expr(left, local_slots);
223            resolve_expr(right, local_slots);
224        }
225        Expr::Match { subject, arms, .. } => {
226            resolve_expr(subject, local_slots);
227            for arm in arms {
228                resolve_expr(&mut arm.body, local_slots);
229            }
230        }
231        Expr::Pipe(left, right) => {
232            resolve_expr(left, local_slots);
233            resolve_expr(right, local_slots);
234        }
235        Expr::Constructor(_, Some(inner)) => {
236            resolve_expr(inner, local_slots);
237        }
238        Expr::Constructor(_, None) => {}
239        Expr::ErrorProp(inner) => {
240            resolve_expr(inner, local_slots);
241        }
242        Expr::InterpolatedStr(parts) => {
243            for part in parts {
244                if let StrPart::Parsed(e) = part {
245                    resolve_expr(e, local_slots);
246                }
247            }
248        }
249        Expr::List(elements) => {
250            for elem in elements {
251                resolve_expr(elem, local_slots);
252            }
253        }
254        Expr::Tuple(items) => {
255            for item in items {
256                resolve_expr(item, local_slots);
257            }
258        }
259        Expr::MapLiteral(entries) => {
260            for (key, value) in entries {
261                resolve_expr(key, local_slots);
262                resolve_expr(value, local_slots);
263            }
264        }
265        Expr::RecordCreate { fields, .. } => {
266            for (_, expr) in fields {
267                resolve_expr(expr, local_slots);
268            }
269        }
270        Expr::RecordUpdate { base, updates, .. } => {
271            resolve_expr(base, local_slots);
272            for (_, expr) in updates {
273                resolve_expr(expr, local_slots);
274            }
275        }
276        Expr::TailCall(boxed) => {
277            for arg in &mut boxed.1 {
278                resolve_expr(arg, local_slots);
279            }
280        }
281    }
282}
283
284/// Resolve expressions inside statements.
285fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
286    for stmt in stmts {
287        match stmt {
288            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
289                resolve_expr(expr, local_slots);
290            }
291        }
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unresolved `FnDef` on line 1 whose parameters are all typed
    /// `Int`, whose return type is `Int`, and which has no effects or
    /// description.
    fn int_fn(name: &str, param_names: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: param_names
                .iter()
                .map(|p| ((*p).to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    /// Shorthand for an unresolved identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Ident(name.to_string())
    }

    #[test]
    fn resolves_param_to_slot() {
        // fn add(a: Int, b: Int) -> Int = a + b
        let sum = Expr::BinOp(BinOp::Add, Box::new(ident("a")), Box::new(ident("b")));
        let mut fd = int_fn("add", &["a", "b"], FnBody::Expr(sum));

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.as_ref() {
            FnBody::Expr(Expr::BinOp(_, left, right)) => {
                assert_eq!(**left, Expr::Resolved(0));
                assert_eq!(**right, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        // fn f(x: Int) -> Int = Console(x)
        let call = Expr::FnCall(Box::new(ident("Console")), vec![ident("x")]);
        let mut fd = int_fn("f", &["x"], FnBody::Expr(call));

        resolve_fn(&mut fd);

        match fd.body.as_ref() {
            FnBody::Expr(Expr::FnCall(func, args)) => {
                // `Console` is not a local, so it must stay an `Ident`.
                assert_eq!(**func, Expr::Ident("Console".to_string()));
                // `x` is the first parameter: slot 0.
                assert_eq!(args[0], Expr::Resolved(0));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        // fn f(x: Int) -> Int
        //     val y = x + 1
        //     y
        let x_plus_one = Expr::BinOp(
            BinOp::Add,
            Box::new(ident("x")),
            Box::new(Expr::Literal(Literal::Int(1))),
        );
        let block = vec![
            Stmt::Binding("y".to_string(), None, x_plus_one),
            Stmt::Expr(ident("y")),
        ];
        let mut fd = int_fn("f", &["x"], FnBody::Block(block));

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.as_ref() {
            FnBody::Block(stmts) => {
                // val y = x + 1  →  val y = Resolved(0) + 1
                match &stmts[0] {
                    Stmt::Binding(_, _, Expr::BinOp(_, left, _)) => {
                        assert_eq!(**left, Expr::Resolved(0));
                    }
                    other => panic!("unexpected stmt: {:?}", other),
                }
                // y  →  Resolved(1)
                match &stmts[1] {
                    Stmt::Expr(Expr::Resolved(1)) => {}
                    other => panic!("unexpected stmt: {:?}", other),
                }
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        // fn f(x: Int) -> Int = match x: Result.Ok(v) -> v, _ -> 0
        let ok_arm = MatchArm {
            pattern: Pattern::Constructor("Result.Ok".to_string(), vec!["v".to_string()]),
            body: Box::new(ident("v")),
        };
        let fallback_arm = MatchArm {
            pattern: Pattern::Wildcard,
            body: Box::new(Expr::Literal(Literal::Int(0))),
        };
        let matcher = Expr::Match {
            subject: Box::new(ident("x")),
            arms: vec![ok_arm, fallback_arm],
            line: 1,
        };
        let mut fd = int_fn("f", &["x"], FnBody::Expr(matcher));

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        // x=0, v=1
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.as_ref() {
            FnBody::Expr(Expr::Match { arms, .. }) => {
                // `v` in the first arm's body should be Resolved(1)
                assert_eq!(*arms[0].body, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }
}