1use std::collections::HashMap;
13use std::sync::Arc as Rc;
14
15use crate::ast::*;
16
17pub fn resolve_program(items: &mut [TopLevel]) {
19 for item in items.iter_mut() {
20 if let TopLevel::FnDef(fd) = item {
21 resolve_fn(fd);
22 }
23 }
24}
25
26fn resolve_fn(fd: &mut FnDef) {
28 let mut local_slots: HashMap<String, u16> = HashMap::new();
29 let mut next_slot: u16 = 0;
30
31 for (param_name, _) in &fd.params {
33 local_slots.insert(param_name.clone(), next_slot);
34 next_slot += 1;
35 }
36
37 collect_binding_slots(fd.body.stmts(), &mut local_slots, &mut next_slot);
39
40 let mut body = fd.body.as_ref().clone();
42 resolve_stmts(body.stmts_mut(), &local_slots);
43 fd.body = Rc::new(body);
44
45 fd.resolution = Some(FnResolution {
46 local_count: next_slot,
47 local_slots: Rc::new(local_slots),
48 });
49}
50
51fn collect_binding_slots(
54 stmts: &[Stmt],
55 local_slots: &mut HashMap<String, u16>,
56 next_slot: &mut u16,
57) {
58 for stmt in stmts {
59 match stmt {
60 Stmt::Binding(name, _, expr) => {
61 if !local_slots.contains_key(name) {
62 local_slots.insert(name.clone(), *next_slot);
63 *next_slot += 1;
64 }
65 collect_expr_bindings(expr, local_slots, next_slot);
66 }
67 Stmt::Expr(expr) => {
68 collect_expr_bindings(expr, local_slots, next_slot);
69 }
70 }
71 }
72}
73
74fn collect_expr_bindings(
76 expr: &Spanned<Expr>,
77 local_slots: &mut HashMap<String, u16>,
78 next_slot: &mut u16,
79) {
80 match &expr.node {
81 Expr::Match { subject, arms } => {
82 collect_expr_bindings(subject, local_slots, next_slot);
83 for arm in arms {
84 collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
85 collect_expr_bindings(&arm.body, local_slots, next_slot);
86 }
87 }
88 Expr::BinOp(_, left, right) => {
89 collect_expr_bindings(left, local_slots, next_slot);
90 collect_expr_bindings(right, local_slots, next_slot);
91 }
92 Expr::FnCall(func, args) => {
93 collect_expr_bindings(func, local_slots, next_slot);
94 for arg in args {
95 collect_expr_bindings(arg, local_slots, next_slot);
96 }
97 }
98 Expr::ErrorProp(inner) => {
99 collect_expr_bindings(inner, local_slots, next_slot);
100 }
101 Expr::Constructor(_, Some(inner)) => {
102 collect_expr_bindings(inner, local_slots, next_slot);
103 }
104 Expr::List(elements) => {
105 for elem in elements {
106 collect_expr_bindings(elem, local_slots, next_slot);
107 }
108 }
109 Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
110 for item in items {
111 collect_expr_bindings(item, local_slots, next_slot);
112 }
113 }
114 Expr::MapLiteral(entries) => {
115 for (key, value) in entries {
116 collect_expr_bindings(key, local_slots, next_slot);
117 collect_expr_bindings(value, local_slots, next_slot);
118 }
119 }
120 Expr::InterpolatedStr(parts) => {
121 for part in parts {
122 if let StrPart::Parsed(e) = part {
123 collect_expr_bindings(e, local_slots, next_slot);
124 }
125 }
126 }
127 Expr::RecordCreate { fields, .. } => {
128 for (_, expr) in fields {
129 collect_expr_bindings(expr, local_slots, next_slot);
130 }
131 }
132 Expr::RecordUpdate { base, updates, .. } => {
133 collect_expr_bindings(base, local_slots, next_slot);
134 for (_, expr) in updates {
135 collect_expr_bindings(expr, local_slots, next_slot);
136 }
137 }
138 Expr::Attr(obj, _) => {
139 collect_expr_bindings(obj, local_slots, next_slot);
140 }
141 Expr::TailCall(boxed) => {
142 for arg in &boxed.1 {
143 collect_expr_bindings(arg, local_slots, next_slot);
144 }
145 }
146 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
148 }
149}
150
151fn collect_pattern_bindings(
153 pattern: &Pattern,
154 local_slots: &mut HashMap<String, u16>,
155 next_slot: &mut u16,
156) {
157 match pattern {
158 Pattern::Ident(name) => {
159 if !local_slots.contains_key(name) {
160 local_slots.insert(name.clone(), *next_slot);
161 *next_slot += 1;
162 }
163 }
164 Pattern::Cons(head, tail) => {
165 for name in [head, tail] {
166 if name != "_" && !local_slots.contains_key(name) {
167 local_slots.insert(name.clone(), *next_slot);
168 *next_slot += 1;
169 }
170 }
171 }
172 Pattern::Constructor(_, bindings) => {
173 for name in bindings {
174 if name != "_" && !local_slots.contains_key(name) {
175 local_slots.insert(name.clone(), *next_slot);
176 *next_slot += 1;
177 }
178 }
179 }
180 Pattern::Tuple(items) => {
181 for item in items {
182 collect_pattern_bindings(item, local_slots, next_slot);
183 }
184 }
185 Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
186 }
187}
188
189fn resolve_expr(expr: &mut Spanned<Expr>, local_slots: &HashMap<String, u16>) {
191 match &mut expr.node {
192 Expr::Ident(name) => {
193 if let Some(&slot) = local_slots.get(name) {
194 expr.node = Expr::Resolved(slot);
195 }
196 }
198 Expr::Resolved(_) | Expr::Literal(_) => {}
199 Expr::Attr(obj, _) => {
200 resolve_expr(obj, local_slots);
201 }
202 Expr::FnCall(func, args) => {
203 resolve_expr(func, local_slots);
204 for arg in args {
205 resolve_expr(arg, local_slots);
206 }
207 }
208 Expr::BinOp(_, left, right) => {
209 resolve_expr(left, local_slots);
210 resolve_expr(right, local_slots);
211 }
212 Expr::Match { subject, arms } => {
213 resolve_expr(subject, local_slots);
214 for arm in arms {
215 resolve_expr(&mut arm.body, local_slots);
216 }
217 }
218 Expr::Constructor(_, Some(inner)) => {
219 resolve_expr(inner, local_slots);
220 }
221 Expr::Constructor(_, None) => {}
222 Expr::ErrorProp(inner) => {
223 resolve_expr(inner, local_slots);
224 }
225 Expr::InterpolatedStr(parts) => {
226 for part in parts {
227 if let StrPart::Parsed(e) = part {
228 resolve_expr(e, local_slots);
229 }
230 }
231 }
232 Expr::List(elements) => {
233 for elem in elements {
234 resolve_expr(elem, local_slots);
235 }
236 }
237 Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
238 for item in items {
239 resolve_expr(item, local_slots);
240 }
241 }
242 Expr::MapLiteral(entries) => {
243 for (key, value) in entries {
244 resolve_expr(key, local_slots);
245 resolve_expr(value, local_slots);
246 }
247 }
248 Expr::RecordCreate { fields, .. } => {
249 for (_, expr) in fields {
250 resolve_expr(expr, local_slots);
251 }
252 }
253 Expr::RecordUpdate { base, updates, .. } => {
254 resolve_expr(base, local_slots);
255 for (_, expr) in updates {
256 resolve_expr(expr, local_slots);
257 }
258 }
259 Expr::TailCall(boxed) => {
260 for arg in &mut boxed.1 {
261 resolve_expr(arg, local_slots);
262 }
263 }
264 }
265}
266
267fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
269 for stmt in stmts {
270 match stmt {
271 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
272 resolve_expr(expr, local_slots);
273 }
274 }
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unresolved function `name` whose parameters and return type
    /// are all `Int`, with the given body.
    fn int_fn(name: &str, params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: params
                .iter()
                .map(|&p| (p.to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    /// Shorthand for an unspanned identifier expression.
    fn ident(name: &str) -> Spanned<Expr> {
        Spanned::bare(Expr::Ident(name.to_string()))
    }

    #[test]
    fn resolves_param_to_slot() {
        let mut fd = int_fn(
            "add",
            &["a", "b"],
            FnBody::from_expr(Spanned::bare(Expr::BinOp(
                BinOp::Add,
                Box::new(ident("a")),
                Box::new(ident("b")),
            ))),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::BinOp(_, left, right),
                ..
            }) => {
                assert_eq!(left.node, Expr::Resolved(0));
                assert_eq!(right.node, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::from_expr(Spanned::bare(Expr::FnCall(
                Box::new(ident("Console")),
                vec![ident("x")],
            ))),
        );
        resolve_fn(&mut fd);
        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::FnCall(func, args),
                ..
            }) => {
                assert_eq!(func.node, Expr::Ident("Console".to_string()));
                assert_eq!(args[0].node, Expr::Resolved(0));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::Block(vec![
                Stmt::Binding(
                    "y".to_string(),
                    None,
                    Spanned::bare(Expr::BinOp(
                        BinOp::Add,
                        Box::new(ident("x")),
                        Box::new(Spanned::bare(Expr::Literal(Literal::Int(1)))),
                    )),
                ),
                Stmt::Expr(ident("y")),
            ]),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::BinOp(_, left, _),
                    ..
                },
            ) => {
                assert_eq!(left.node, Expr::Resolved(0));
            }
            other => panic!("unexpected stmt: {:?}", other),
        }
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved(1),
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn reuses_slot_when_name_is_rebound() {
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::Block(vec![
                Stmt::Binding("y".to_string(), None, ident("x")),
                Stmt::Binding("y".to_string(), None, ident("y")),
                Stmt::Expr(ident("y")),
            ]),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["y"], 1);
        // The second `y` binding must not allocate a fresh slot.
        assert_eq!(res.local_count, 2);

        match &fd.body.stmts()[1] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::Resolved(1),
                    ..
                },
            ) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::from_expr(Spanned::new(
                Expr::Match {
                    subject: Box::new(ident("x")),
                    arms: vec![
                        MatchArm {
                            pattern: Pattern::Constructor(
                                "Result.Ok".to_string(),
                                vec!["v".to_string()],
                            ),
                            body: Box::new(ident("v")),
                        },
                        MatchArm {
                            pattern: Pattern::Wildcard,
                            body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                        },
                    ],
                },
                1,
            )),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => {
                assert_eq!(arms[0].body.node, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn skips_underscore_in_cons_and_constructor_patterns() {
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::from_expr(Spanned::bare(Expr::Match {
                subject: Box::new(ident("x")),
                arms: vec![
                    MatchArm {
                        pattern: Pattern::Cons("h".to_string(), "_".to_string()),
                        body: Box::new(ident("h")),
                    },
                    MatchArm {
                        pattern: Pattern::Constructor(
                            "Pair".to_string(),
                            vec!["_".to_string(), "b".to_string()],
                        ),
                        body: Box::new(ident("b")),
                    },
                ],
            })),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["h"], 1);
        assert_eq!(res.local_slots["b"], 2);
        assert!(!res.local_slots.contains_key("_"));
        assert_eq!(res.local_count, 3);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => {
                assert_eq!(arms[0].body.node, Expr::Resolved(1));
                assert_eq!(arms[1].body.node, Expr::Resolved(2));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_nested_tuple_pattern_bindings() {
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::from_expr(Spanned::bare(Expr::Match {
                subject: Box::new(ident("x")),
                arms: vec![MatchArm {
                    pattern: Pattern::Tuple(vec![
                        Pattern::Ident("a".to_string()),
                        Pattern::Wildcard,
                        Pattern::Ident("b".to_string()),
                    ]),
                    body: Box::new(Spanned::bare(Expr::BinOp(
                        BinOp::Add,
                        Box::new(ident("a")),
                        Box::new(ident("b")),
                    ))),
                }],
            })),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 1);
        assert_eq!(res.local_slots["b"], 2);
        assert_eq!(res.local_count, 3);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => match &arms[0].body.node {
                Expr::BinOp(_, left, right) => {
                    assert_eq!(left.node, Expr::Resolved(1));
                    assert_eq!(right.node, Expr::Resolved(2));
                }
                other => panic!("unexpected arm body: {:?}", other),
            },
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings_inside_binding_initializer() {
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::Block(vec![
                Stmt::Binding(
                    "result".to_string(),
                    None,
                    Spanned::bare(Expr::Match {
                        subject: Box::new(ident("x")),
                        arms: vec![
                            MatchArm {
                                pattern: Pattern::Constructor(
                                    "Option.Some".to_string(),
                                    vec!["v".to_string()],
                                ),
                                body: Box::new(ident("v")),
                            },
                            MatchArm {
                                pattern: Pattern::Wildcard,
                                body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                            },
                        ],
                    }),
                ),
                Stmt::Expr(ident("result")),
            ]),
        );

        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["result"], 1);
        assert_eq!(res.local_slots["v"], 2);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::Match { arms, .. },
                    ..
                },
            ) => {
                assert_eq!(arms[0].body.node, Expr::Resolved(2));
            }
            other => panic!("unexpected stmt: {:?}", other),
        }

        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved(1),
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }
}