// aver/resolver.rs

//! Compile-time variable resolution pass.
//!
//! After parsing and before interpretation, this pass walks each `FnDef` body
//! and replaces `Expr::Ident(name)` with `Expr::Resolved { slot, .. }` for
//! variables that are local to the function (parameters, `val`/`var`
//! bindings, and names bound by `match` patterns).
//!
//! Global/namespace identifiers are left as `Expr::Ident` — the VM
//! falls back to HashMap lookup for those.
//!
//! Only top-level `FnDef` bodies are resolved. Top-level `Stmt` items (globals,
//! REPL) are not touched.
use std::collections::HashMap;
use std::mem;
use std::sync::Arc as Rc;

use crate::ast::*;

17/// Run the resolver on all top-level function definitions. Stops after
18/// slot resolution — last-use ownership annotation is its own pipeline
19/// stage (`ir::pipeline::last_use`) so the two analyses are individually
20/// observable and skippable.
21pub fn resolve_program(items: &mut [TopLevel]) {
22    for item in items.iter_mut() {
23        if let TopLevel::FnDef(fd) = item {
24            resolve_fn(fd);
25        }
26    }
27}
28
29/// Resolve a single function definition.
30fn resolve_fn(fd: &mut FnDef) {
31    let mut local_slots: HashMap<String, u16> = HashMap::new();
32    let mut next_slot: u16 = 0;
33
34    // Params get slots 0..N-1
35    for (param_name, _) in &fd.params {
36        local_slots.insert(param_name.clone(), next_slot);
37        next_slot += 1;
38    }
39
40    // Scan body for val/var bindings to pre-allocate slots
41    collect_binding_slots(fd.body.stmts(), &mut local_slots, &mut next_slot);
42
43    // Resolve expressions in the body
44    let mut body = fd.body.as_ref().clone();
45    resolve_stmts(body.stmts_mut(), &local_slots);
46    fd.body = Rc::new(body);
47
48    fd.resolution = Some(FnResolution {
49        local_count: next_slot,
50        local_slots: Rc::new(local_slots),
51    });
52}
53
54/// Collect all binding names from a statement list and assign slots.
55/// This handles match arms recursively (pattern bindings get slots too).
56fn collect_binding_slots(
57    stmts: &[Stmt],
58    local_slots: &mut HashMap<String, u16>,
59    next_slot: &mut u16,
60) {
61    for stmt in stmts {
62        match stmt {
63            Stmt::Binding(name, _, expr) => {
64                if !local_slots.contains_key(name) {
65                    local_slots.insert(name.clone(), *next_slot);
66                    *next_slot += 1;
67                }
68                collect_expr_bindings(expr, local_slots, next_slot);
69            }
70            Stmt::Expr(expr) => {
71                collect_expr_bindings(expr, local_slots, next_slot);
72            }
73        }
74    }
75}
76
77/// Collect pattern bindings from match expressions inside an expression tree.
78fn collect_expr_bindings(
79    expr: &Spanned<Expr>,
80    local_slots: &mut HashMap<String, u16>,
81    next_slot: &mut u16,
82) {
83    match &expr.node {
84        Expr::Match { subject, arms } => {
85            collect_expr_bindings(subject, local_slots, next_slot);
86            for arm in arms {
87                collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
88                collect_expr_bindings(&arm.body, local_slots, next_slot);
89            }
90        }
91        Expr::BinOp(_, left, right) => {
92            collect_expr_bindings(left, local_slots, next_slot);
93            collect_expr_bindings(right, local_slots, next_slot);
94        }
95        Expr::FnCall(func, args) => {
96            collect_expr_bindings(func, local_slots, next_slot);
97            for arg in args {
98                collect_expr_bindings(arg, local_slots, next_slot);
99            }
100        }
101        Expr::ErrorProp(inner) => {
102            collect_expr_bindings(inner, local_slots, next_slot);
103        }
104        Expr::Constructor(_, Some(inner)) => {
105            collect_expr_bindings(inner, local_slots, next_slot);
106        }
107        Expr::List(elements) => {
108            for elem in elements {
109                collect_expr_bindings(elem, local_slots, next_slot);
110            }
111        }
112        Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
113            for item in items {
114                collect_expr_bindings(item, local_slots, next_slot);
115            }
116        }
117        Expr::MapLiteral(entries) => {
118            for (key, value) in entries {
119                collect_expr_bindings(key, local_slots, next_slot);
120                collect_expr_bindings(value, local_slots, next_slot);
121            }
122        }
123        Expr::InterpolatedStr(parts) => {
124            for part in parts {
125                if let StrPart::Parsed(e) = part {
126                    collect_expr_bindings(e, local_slots, next_slot);
127                }
128            }
129        }
130        Expr::RecordCreate { fields, .. } => {
131            for (_, expr) in fields {
132                collect_expr_bindings(expr, local_slots, next_slot);
133            }
134        }
135        Expr::RecordUpdate { base, updates, .. } => {
136            collect_expr_bindings(base, local_slots, next_slot);
137            for (_, expr) in updates {
138                collect_expr_bindings(expr, local_slots, next_slot);
139            }
140        }
141        Expr::Attr(obj, _) => {
142            collect_expr_bindings(obj, local_slots, next_slot);
143        }
144        Expr::TailCall(boxed) => {
145            for arg in &boxed.args {
146                collect_expr_bindings(arg, local_slots, next_slot);
147            }
148        }
149        // Leaves — no bindings to collect
150        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
151    }
152}
153
154/// Assign slots for names introduced by a pattern.
155fn collect_pattern_bindings(
156    pattern: &Pattern,
157    local_slots: &mut HashMap<String, u16>,
158    next_slot: &mut u16,
159) {
160    match pattern {
161        Pattern::Ident(name) => {
162            if !local_slots.contains_key(name) {
163                local_slots.insert(name.clone(), *next_slot);
164                *next_slot += 1;
165            }
166        }
167        Pattern::Cons(head, tail) => {
168            for name in [head, tail] {
169                if name != "_" && !local_slots.contains_key(name) {
170                    local_slots.insert(name.clone(), *next_slot);
171                    *next_slot += 1;
172                }
173            }
174        }
175        Pattern::Constructor(_, bindings) => {
176            for name in bindings {
177                if name != "_" && !local_slots.contains_key(name) {
178                    local_slots.insert(name.clone(), *next_slot);
179                    *next_slot += 1;
180                }
181            }
182        }
183        Pattern::Tuple(items) => {
184            for item in items {
185                collect_pattern_bindings(item, local_slots, next_slot);
186            }
187        }
188        Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
189    }
190}
191
192/// Resolve `Expr::Ident` → `Expr::Resolved` for locals in an expression.
193fn resolve_expr(expr: &mut Spanned<Expr>, local_slots: &HashMap<String, u16>) {
194    match &mut expr.node {
195        Expr::Ident(name) => {
196            if let Some(&slot) = local_slots.get(name) {
197                expr.node = Expr::Resolved {
198                    slot,
199                    name: name.clone(),
200                    last_use: AnnotBool(false),
201                };
202            }
203            // else: global/namespace — leave as Ident for HashMap fallback
204        }
205        Expr::Resolved { .. } | Expr::Literal(_) => {}
206        Expr::Attr(obj, _) => {
207            resolve_expr(obj, local_slots);
208        }
209        Expr::FnCall(func, args) => {
210            resolve_expr(func, local_slots);
211            for arg in args {
212                resolve_expr(arg, local_slots);
213            }
214        }
215        Expr::BinOp(_, left, right) => {
216            resolve_expr(left, local_slots);
217            resolve_expr(right, local_slots);
218        }
219        Expr::Match { subject, arms } => {
220            resolve_expr(subject, local_slots);
221            for arm in arms {
222                resolve_expr(&mut arm.body, local_slots);
223            }
224        }
225        Expr::Constructor(_, Some(inner)) => {
226            resolve_expr(inner, local_slots);
227        }
228        Expr::Constructor(_, None) => {}
229        Expr::ErrorProp(inner) => {
230            resolve_expr(inner, local_slots);
231        }
232        Expr::InterpolatedStr(parts) => {
233            for part in parts {
234                if let StrPart::Parsed(e) = part {
235                    resolve_expr(e, local_slots);
236                }
237            }
238        }
239        Expr::List(elements) => {
240            for elem in elements {
241                resolve_expr(elem, local_slots);
242            }
243        }
244        Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
245            for item in items {
246                resolve_expr(item, local_slots);
247            }
248        }
249        Expr::MapLiteral(entries) => {
250            for (key, value) in entries {
251                resolve_expr(key, local_slots);
252                resolve_expr(value, local_slots);
253            }
254        }
255        Expr::RecordCreate { fields, .. } => {
256            for (_, expr) in fields {
257                resolve_expr(expr, local_slots);
258            }
259        }
260        Expr::RecordUpdate { base, updates, .. } => {
261            resolve_expr(base, local_slots);
262            for (_, expr) in updates {
263                resolve_expr(expr, local_slots);
264            }
265        }
266        Expr::TailCall(boxed) => {
267            for arg in &mut boxed.args {
268                resolve_expr(arg, local_slots);
269            }
270        }
271    }
272}
273
274/// Resolve expressions inside statements.
275fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
276    for stmt in stmts {
277        match stmt {
278            Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
279                resolve_expr(expr, local_slots);
280            }
281        }
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unresolved `FnDef` named `f` returning `Int`, whose
    /// parameters are the given names (all typed `Int`).
    fn fn_def(params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: "f".to_string(),
            line: 1,
            params: params
                .iter()
                .map(|&p| (p.to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    /// An unresolved identifier reference.
    fn ident(name: &str) -> Spanned<Expr> {
        Spanned::bare(Expr::Ident(name.to_string()))
    }

    /// The node the resolver is expected to produce for local `name` in `slot`.
    fn resolved(slot: u16, name: &str) -> Expr {
        Expr::Resolved {
            slot,
            name: name.to_string(),
            last_use: AnnotBool(false),
        }
    }

    #[test]
    fn resolves_param_to_slot() {
        // params (a, b); body: a + b
        let mut fd = fn_def(
            &["a", "b"],
            FnBody::from_expr(Spanned::bare(Expr::BinOp(
                BinOp::Add,
                Box::new(ident("a")),
                Box::new(ident("b")),
            ))),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::BinOp(_, left, right),
                ..
            }) => {
                assert_eq!(left.node, resolved(0, "a"));
                assert_eq!(right.node, resolved(1, "b"));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        // params (x); body: Console(x) — `Console` is not a local
        let mut fd = fn_def(
            &["x"],
            FnBody::from_expr(Spanned::bare(Expr::FnCall(
                Box::new(ident("Console")),
                vec![ident("x")],
            ))),
        );
        resolve_fn(&mut fd);
        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::FnCall(func, args),
                ..
            }) => {
                assert_eq!(func.node, Expr::Ident("Console".to_string()));
                assert_eq!(args[0].node, resolved(0, "x"));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        // params (x); body: { y = x + 1; y }
        let mut fd = fn_def(
            &["x"],
            FnBody::Block(vec![
                Stmt::Binding(
                    "y".to_string(),
                    None,
                    Spanned::bare(Expr::BinOp(
                        BinOp::Add,
                        Box::new(ident("x")),
                        Box::new(Spanned::bare(Expr::Literal(Literal::Int(1)))),
                    )),
                ),
                Stmt::Expr(ident("y")),
            ]),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        // y = x + 1  →  `x` becomes the slot-0 local
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::BinOp(_, left, _),
                    ..
                },
            ) => {
                assert_eq!(left.node, resolved(0, "x"));
            }
            other => panic!("unexpected stmt: {:?}", other),
        }
        // y  →  slot 1
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved { slot: 1, .. },
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        // params (x); body: match x: Result.Ok(v) -> v, _ -> 0
        let mut fd = fn_def(
            &["x"],
            FnBody::from_expr(Spanned::new(
                Expr::Match {
                    subject: Box::new(ident("x")),
                    arms: vec![
                        MatchArm {
                            pattern: Pattern::Constructor(
                                "Result.Ok".to_string(),
                                vec!["v".to_string()],
                            ),
                            body: Box::new(ident("v")),
                        },
                        MatchArm {
                            pattern: Pattern::Wildcard,
                            body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                        },
                    ],
                },
                1,
            )),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        // x=0, v=1
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => {
                assert_eq!(arms[0].body.node, resolved(1, "v"));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings_inside_binding_initializer() {
        // params (x); body: { result = match x: Option.Some(v) -> v, _ -> 0; result }
        let mut fd = fn_def(
            &["x"],
            FnBody::Block(vec![
                Stmt::Binding(
                    "result".to_string(),
                    None,
                    Spanned::bare(Expr::Match {
                        subject: Box::new(ident("x")),
                        arms: vec![
                            MatchArm {
                                pattern: Pattern::Constructor(
                                    "Option.Some".to_string(),
                                    vec!["v".to_string()],
                                ),
                                body: Box::new(ident("v")),
                            },
                            MatchArm {
                                pattern: Pattern::Wildcard,
                                body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                            },
                        ],
                    }),
                ),
                Stmt::Expr(ident("result")),
            ]),
        );

        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        // The binding name is reserved before its initializer is scanned.
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["result"], 1);
        assert_eq!(res.local_slots["v"], 2);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::Match { arms, .. },
                    ..
                },
            ) => {
                assert_eq!(arms[0].body.node, resolved(2, "v"));
            }
            other => panic!("unexpected stmt: {:?}", other),
        }

        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved { slot: 1, .. },
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn cons_pattern_does_not_allocate_slot_for_underscore() {
        // A cons pattern whose tail is `_` binds only its head.
        let mut fd = fn_def(
            &["xs"],
            FnBody::from_expr(Spanned::bare(Expr::Match {
                subject: Box::new(ident("xs")),
                arms: vec![
                    MatchArm {
                        pattern: Pattern::Cons("h".to_string(), "_".to_string()),
                        body: Box::new(ident("h")),
                    },
                    MatchArm {
                        pattern: Pattern::Wildcard,
                        body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                    },
                ],
            })),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["h"], 1);
        assert!(!res.local_slots.contains_key("_"));
        assert_eq!(res.local_count, 2);
    }

    #[test]
    fn binding_that_reuses_param_name_keeps_param_slot() {
        // params (x); body: { x = x + 1; x }
        let mut fd = fn_def(
            &["x"],
            FnBody::Block(vec![
                Stmt::Binding(
                    "x".to_string(),
                    None,
                    Spanned::bare(Expr::BinOp(
                        BinOp::Add,
                        Box::new(ident("x")),
                        Box::new(Spanned::bare(Expr::Literal(Literal::Int(1)))),
                    )),
                ),
                Stmt::Expr(ident("x")),
            ]),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_count, 1);
    }

    #[test]
    fn shared_body_is_not_rewritten_for_other_owners() {
        let mut fd = fn_def(&["x"], FnBody::from_expr(ident("x")));
        let other_owner = Rc::clone(&fd.body);
        resolve_fn(&mut fd);
        // The other owner still sees the unresolved body ...
        assert_eq!(
            other_owner.tail_expr().map(|e| &e.node),
            Some(&Expr::Ident("x".to_string()))
        );
        // ... while the function itself carries the resolved one.
        assert_eq!(
            fd.body.tail_expr().map(|e| &e.node),
            Some(&resolved(0, "x"))
        );
    }
}