1use std::collections::HashMap;
13use std::rc::Rc;
14
15use crate::ast::*;
16
17pub fn resolve_program(items: &mut [TopLevel]) {
19 for item in items.iter_mut() {
20 if let TopLevel::FnDef(fd) = item {
21 resolve_fn(fd);
22 }
23 }
24}
25
26fn resolve_fn(fd: &mut FnDef) {
28 let mut local_slots: HashMap<String, u16> = HashMap::new();
29 let mut next_slot: u16 = 0;
30
31 for (param_name, _) in &fd.params {
33 local_slots.insert(param_name.clone(), next_slot);
34 next_slot += 1;
35 }
36
37 collect_binding_slots(fd.body.stmts(), &mut local_slots, &mut next_slot);
39
40 let mut body = fd.body.as_ref().clone();
42 resolve_stmts(body.stmts_mut(), &local_slots);
43 fd.body = Rc::new(body);
44
45 fd.resolution = Some(FnResolution {
46 local_count: next_slot,
47 local_slots,
48 });
49}
50
51fn collect_binding_slots(
54 stmts: &[Stmt],
55 local_slots: &mut HashMap<String, u16>,
56 next_slot: &mut u16,
57) {
58 for stmt in stmts {
59 match stmt {
60 Stmt::Binding(name, _, _) => {
61 if !local_slots.contains_key(name) {
62 local_slots.insert(name.clone(), *next_slot);
63 *next_slot += 1;
64 }
65 }
66 Stmt::Expr(expr) => {
67 collect_expr_bindings(expr, local_slots, next_slot);
68 }
69 }
70 }
71}
72
73fn collect_expr_bindings(expr: &Expr, local_slots: &mut HashMap<String, u16>, next_slot: &mut u16) {
75 match expr {
76 Expr::Match { subject, arms, .. } => {
77 collect_expr_bindings(subject, local_slots, next_slot);
78 for arm in arms {
79 collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
80 collect_expr_bindings(&arm.body, local_slots, next_slot);
81 }
82 }
83 Expr::BinOp(_, left, right) => {
84 collect_expr_bindings(left, local_slots, next_slot);
85 collect_expr_bindings(right, local_slots, next_slot);
86 }
87 Expr::FnCall(func, args) => {
88 collect_expr_bindings(func, local_slots, next_slot);
89 for arg in args {
90 collect_expr_bindings(arg, local_slots, next_slot);
91 }
92 }
93 Expr::ErrorProp(inner) => {
94 collect_expr_bindings(inner, local_slots, next_slot);
95 }
96 Expr::Constructor(_, Some(inner)) => {
97 collect_expr_bindings(inner, local_slots, next_slot);
98 }
99 Expr::List(elements) => {
100 for elem in elements {
101 collect_expr_bindings(elem, local_slots, next_slot);
102 }
103 }
104 Expr::Tuple(items) => {
105 for item in items {
106 collect_expr_bindings(item, local_slots, next_slot);
107 }
108 }
109 Expr::MapLiteral(entries) => {
110 for (key, value) in entries {
111 collect_expr_bindings(key, local_slots, next_slot);
112 collect_expr_bindings(value, local_slots, next_slot);
113 }
114 }
115 Expr::InterpolatedStr(parts) => {
116 for part in parts {
117 if let StrPart::Parsed(e) = part {
118 collect_expr_bindings(e, local_slots, next_slot);
119 }
120 }
121 }
122 Expr::RecordCreate { fields, .. } => {
123 for (_, expr) in fields {
124 collect_expr_bindings(expr, local_slots, next_slot);
125 }
126 }
127 Expr::RecordUpdate { base, updates, .. } => {
128 collect_expr_bindings(base, local_slots, next_slot);
129 for (_, expr) in updates {
130 collect_expr_bindings(expr, local_slots, next_slot);
131 }
132 }
133 Expr::Attr(obj, _) => {
134 collect_expr_bindings(obj, local_slots, next_slot);
135 }
136 Expr::TailCall(boxed) => {
137 for arg in &boxed.1 {
138 collect_expr_bindings(arg, local_slots, next_slot);
139 }
140 }
141 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
143 }
144}
145
146fn collect_pattern_bindings(
148 pattern: &Pattern,
149 local_slots: &mut HashMap<String, u16>,
150 next_slot: &mut u16,
151) {
152 match pattern {
153 Pattern::Ident(name) => {
154 if !local_slots.contains_key(name) {
155 local_slots.insert(name.clone(), *next_slot);
156 *next_slot += 1;
157 }
158 }
159 Pattern::Cons(head, tail) => {
160 for name in [head, tail] {
161 if name != "_" && !local_slots.contains_key(name) {
162 local_slots.insert(name.clone(), *next_slot);
163 *next_slot += 1;
164 }
165 }
166 }
167 Pattern::Constructor(_, bindings) => {
168 for name in bindings {
169 if name != "_" && !local_slots.contains_key(name) {
170 local_slots.insert(name.clone(), *next_slot);
171 *next_slot += 1;
172 }
173 }
174 }
175 Pattern::Tuple(items) => {
176 for item in items {
177 collect_pattern_bindings(item, local_slots, next_slot);
178 }
179 }
180 Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
181 }
182}
183
184fn resolve_expr(expr: &mut Expr, local_slots: &HashMap<String, u16>) {
186 match expr {
187 Expr::Ident(name) => {
188 if let Some(&slot) = local_slots.get(name) {
189 *expr = Expr::Resolved(slot);
190 }
191 }
193 Expr::Resolved(_) | Expr::Literal(_) => {}
194 Expr::Attr(obj, _) => {
195 resolve_expr(obj, local_slots);
196 }
197 Expr::FnCall(func, args) => {
198 resolve_expr(func, local_slots);
199 for arg in args {
200 resolve_expr(arg, local_slots);
201 }
202 }
203 Expr::BinOp(_, left, right) => {
204 resolve_expr(left, local_slots);
205 resolve_expr(right, local_slots);
206 }
207 Expr::Match { subject, arms, .. } => {
208 resolve_expr(subject, local_slots);
209 for arm in arms {
210 resolve_expr(&mut arm.body, local_slots);
211 }
212 }
213 Expr::Constructor(_, Some(inner)) => {
214 resolve_expr(inner, local_slots);
215 }
216 Expr::Constructor(_, None) => {}
217 Expr::ErrorProp(inner) => {
218 resolve_expr(inner, local_slots);
219 }
220 Expr::InterpolatedStr(parts) => {
221 for part in parts {
222 if let StrPart::Parsed(e) = part {
223 resolve_expr(e, local_slots);
224 }
225 }
226 }
227 Expr::List(elements) => {
228 for elem in elements {
229 resolve_expr(elem, local_slots);
230 }
231 }
232 Expr::Tuple(items) => {
233 for item in items {
234 resolve_expr(item, local_slots);
235 }
236 }
237 Expr::MapLiteral(entries) => {
238 for (key, value) in entries {
239 resolve_expr(key, local_slots);
240 resolve_expr(value, local_slots);
241 }
242 }
243 Expr::RecordCreate { fields, .. } => {
244 for (_, expr) in fields {
245 resolve_expr(expr, local_slots);
246 }
247 }
248 Expr::RecordUpdate { base, updates, .. } => {
249 resolve_expr(base, local_slots);
250 for (_, expr) in updates {
251 resolve_expr(expr, local_slots);
252 }
253 }
254 Expr::TailCall(boxed) => {
255 for arg in &mut boxed.1 {
256 resolve_expr(arg, local_slots);
257 }
258 }
259 }
260}
261
262fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
264 for stmt in stmts {
265 match stmt {
266 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
267 resolve_expr(expr, local_slots);
268 }
269 }
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a function named `name` at line 1 whose parameters and return
    /// type are all `Int`, with no effects, no description and no resolution.
    fn int_fn(name: &str, params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: params
                .iter()
                .map(|param| (String::from(*param), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    /// Shorthand for an unresolved identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Ident(name.to_string())
    }

    #[test]
    fn resolves_param_to_slot() {
        let sum = Expr::BinOp(BinOp::Add, Box::new(ident("a")), Box::new(ident("b")));
        let mut fd = int_fn("add", &["a", "b"], FnBody::from_expr(sum));

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Expr::BinOp(_, left, right)) => {
                assert_eq!(**left, Expr::Resolved(0));
                assert_eq!(**right, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        let call = Expr::FnCall(Box::new(ident("Console")), vec![ident("x")]);
        let mut fd = int_fn("f", &["x"], FnBody::from_expr(call));

        resolve_fn(&mut fd);

        match fd.body.tail_expr() {
            Some(Expr::FnCall(func, args)) => {
                assert_eq!(**func, Expr::Ident("Console".to_string()));
                assert_eq!(args[0], Expr::Resolved(0));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        let x_plus_one = Expr::BinOp(
            BinOp::Add,
            Box::new(ident("x")),
            Box::new(Expr::Literal(Literal::Int(1))),
        );
        let block = FnBody::Block(vec![
            Stmt::Binding("y".to_string(), None, x_plus_one),
            Stmt::Expr(ident("y")),
        ]);
        let mut fd = int_fn("f", &["x"], block);

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(_, _, Expr::BinOp(_, left, _)) => {
                assert_eq!(**left, Expr::Resolved(0));
            }
            other => panic!("unexpected stmt: {:?}", other),
        }
        match &stmts[1] {
            Stmt::Expr(Expr::Resolved(1)) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        let ok_arm = MatchArm {
            pattern: Pattern::Constructor("Result.Ok".to_string(), vec!["v".to_string()]),
            body: Box::new(ident("v")),
        };
        let fallback_arm = MatchArm {
            pattern: Pattern::Wildcard,
            body: Box::new(Expr::Literal(Literal::Int(0))),
        };
        let subject_match = Expr::Match {
            subject: Box::new(ident("x")),
            arms: vec![ok_arm, fallback_arm],
            line: 1,
        };
        let mut fd = int_fn("f", &["x"], FnBody::from_expr(subject_match));

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.tail_expr() {
            Some(Expr::Match { arms, .. }) => {
                assert_eq!(*arms[0].body, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }
}