1use std::collections::HashMap;
13use std::sync::Arc as Rc;
14
15use crate::ast::*;
16
17pub fn resolve_program(items: &mut [TopLevel]) {
20 for item in items.iter_mut() {
21 if let TopLevel::FnDef(fd) = item {
22 resolve_fn(fd);
23 }
24 }
25 crate::ir::last_use::annotate_program_last_use(items);
26}
27
28fn resolve_fn(fd: &mut FnDef) {
30 let mut local_slots: HashMap<String, u16> = HashMap::new();
31 let mut next_slot: u16 = 0;
32
33 for (param_name, _) in &fd.params {
35 local_slots.insert(param_name.clone(), next_slot);
36 next_slot += 1;
37 }
38
39 collect_binding_slots(fd.body.stmts(), &mut local_slots, &mut next_slot);
41
42 let mut body = fd.body.as_ref().clone();
44 resolve_stmts(body.stmts_mut(), &local_slots);
45 fd.body = Rc::new(body);
46
47 fd.resolution = Some(FnResolution {
48 local_count: next_slot,
49 local_slots: Rc::new(local_slots),
50 });
51}
52
53fn collect_binding_slots(
56 stmts: &[Stmt],
57 local_slots: &mut HashMap<String, u16>,
58 next_slot: &mut u16,
59) {
60 for stmt in stmts {
61 match stmt {
62 Stmt::Binding(name, _, expr) => {
63 if !local_slots.contains_key(name) {
64 local_slots.insert(name.clone(), *next_slot);
65 *next_slot += 1;
66 }
67 collect_expr_bindings(expr, local_slots, next_slot);
68 }
69 Stmt::Expr(expr) => {
70 collect_expr_bindings(expr, local_slots, next_slot);
71 }
72 }
73 }
74}
75
76fn collect_expr_bindings(
78 expr: &Spanned<Expr>,
79 local_slots: &mut HashMap<String, u16>,
80 next_slot: &mut u16,
81) {
82 match &expr.node {
83 Expr::Match { subject, arms } => {
84 collect_expr_bindings(subject, local_slots, next_slot);
85 for arm in arms {
86 collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
87 collect_expr_bindings(&arm.body, local_slots, next_slot);
88 }
89 }
90 Expr::BinOp(_, left, right) => {
91 collect_expr_bindings(left, local_slots, next_slot);
92 collect_expr_bindings(right, local_slots, next_slot);
93 }
94 Expr::FnCall(func, args) => {
95 collect_expr_bindings(func, local_slots, next_slot);
96 for arg in args {
97 collect_expr_bindings(arg, local_slots, next_slot);
98 }
99 }
100 Expr::ErrorProp(inner) => {
101 collect_expr_bindings(inner, local_slots, next_slot);
102 }
103 Expr::Constructor(_, Some(inner)) => {
104 collect_expr_bindings(inner, local_slots, next_slot);
105 }
106 Expr::List(elements) => {
107 for elem in elements {
108 collect_expr_bindings(elem, local_slots, next_slot);
109 }
110 }
111 Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
112 for item in items {
113 collect_expr_bindings(item, local_slots, next_slot);
114 }
115 }
116 Expr::MapLiteral(entries) => {
117 for (key, value) in entries {
118 collect_expr_bindings(key, local_slots, next_slot);
119 collect_expr_bindings(value, local_slots, next_slot);
120 }
121 }
122 Expr::InterpolatedStr(parts) => {
123 for part in parts {
124 if let StrPart::Parsed(e) = part {
125 collect_expr_bindings(e, local_slots, next_slot);
126 }
127 }
128 }
129 Expr::RecordCreate { fields, .. } => {
130 for (_, expr) in fields {
131 collect_expr_bindings(expr, local_slots, next_slot);
132 }
133 }
134 Expr::RecordUpdate { base, updates, .. } => {
135 collect_expr_bindings(base, local_slots, next_slot);
136 for (_, expr) in updates {
137 collect_expr_bindings(expr, local_slots, next_slot);
138 }
139 }
140 Expr::Attr(obj, _) => {
141 collect_expr_bindings(obj, local_slots, next_slot);
142 }
143 Expr::TailCall(boxed) => {
144 for arg in &boxed.args {
145 collect_expr_bindings(arg, local_slots, next_slot);
146 }
147 }
148 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
150 }
151}
152
153fn collect_pattern_bindings(
155 pattern: &Pattern,
156 local_slots: &mut HashMap<String, u16>,
157 next_slot: &mut u16,
158) {
159 match pattern {
160 Pattern::Ident(name) => {
161 if !local_slots.contains_key(name) {
162 local_slots.insert(name.clone(), *next_slot);
163 *next_slot += 1;
164 }
165 }
166 Pattern::Cons(head, tail) => {
167 for name in [head, tail] {
168 if name != "_" && !local_slots.contains_key(name) {
169 local_slots.insert(name.clone(), *next_slot);
170 *next_slot += 1;
171 }
172 }
173 }
174 Pattern::Constructor(_, bindings) => {
175 for name in bindings {
176 if name != "_" && !local_slots.contains_key(name) {
177 local_slots.insert(name.clone(), *next_slot);
178 *next_slot += 1;
179 }
180 }
181 }
182 Pattern::Tuple(items) => {
183 for item in items {
184 collect_pattern_bindings(item, local_slots, next_slot);
185 }
186 }
187 Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
188 }
189}
190
191fn resolve_expr(expr: &mut Spanned<Expr>, local_slots: &HashMap<String, u16>) {
193 match &mut expr.node {
194 Expr::Ident(name) => {
195 if let Some(&slot) = local_slots.get(name) {
196 expr.node = Expr::Resolved {
197 slot,
198 name: name.clone(),
199 last_use: AnnotBool(false),
200 };
201 }
202 }
204 Expr::Resolved { .. } | Expr::Literal(_) => {}
205 Expr::Attr(obj, _) => {
206 resolve_expr(obj, local_slots);
207 }
208 Expr::FnCall(func, args) => {
209 resolve_expr(func, local_slots);
210 for arg in args {
211 resolve_expr(arg, local_slots);
212 }
213 }
214 Expr::BinOp(_, left, right) => {
215 resolve_expr(left, local_slots);
216 resolve_expr(right, local_slots);
217 }
218 Expr::Match { subject, arms } => {
219 resolve_expr(subject, local_slots);
220 for arm in arms {
221 resolve_expr(&mut arm.body, local_slots);
222 }
223 }
224 Expr::Constructor(_, Some(inner)) => {
225 resolve_expr(inner, local_slots);
226 }
227 Expr::Constructor(_, None) => {}
228 Expr::ErrorProp(inner) => {
229 resolve_expr(inner, local_slots);
230 }
231 Expr::InterpolatedStr(parts) => {
232 for part in parts {
233 if let StrPart::Parsed(e) = part {
234 resolve_expr(e, local_slots);
235 }
236 }
237 }
238 Expr::List(elements) => {
239 for elem in elements {
240 resolve_expr(elem, local_slots);
241 }
242 }
243 Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
244 for item in items {
245 resolve_expr(item, local_slots);
246 }
247 }
248 Expr::MapLiteral(entries) => {
249 for (key, value) in entries {
250 resolve_expr(key, local_slots);
251 resolve_expr(value, local_slots);
252 }
253 }
254 Expr::RecordCreate { fields, .. } => {
255 for (_, expr) in fields {
256 resolve_expr(expr, local_slots);
257 }
258 }
259 Expr::RecordUpdate { base, updates, .. } => {
260 resolve_expr(base, local_slots);
261 for (_, expr) in updates {
262 resolve_expr(expr, local_slots);
263 }
264 }
265 Expr::TailCall(boxed) => {
266 for arg in &mut boxed.args {
267 resolve_expr(arg, local_slots);
268 }
269 }
270 }
271}
272
273fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
275 for stmt in stmts {
276 match stmt {
277 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
278 resolve_expr(expr, local_slots);
279 }
280 }
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a function at line 1 whose parameters and return type are all
    // `Int`, with no effects, no description and the given unresolved body.
    fn int_fn(name: &str, params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: params
                .iter()
                .map(|param| (param.to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    // An unresolved identifier reference with no span information.
    fn ident(name: &str) -> Spanned<Expr> {
        Spanned::bare(Expr::Ident(name.to_string()))
    }

    // The node resolution is expected to produce for a local reference.
    fn resolved(slot: u16, name: &str) -> Expr {
        Expr::Resolved {
            slot,
            name: name.to_string(),
            last_use: AnnotBool(false),
        }
    }

    // `match x { <ctor>(v) => v, _ => 0 }`: binds `v` in its first arm.
    fn match_x_binding_v(ctor: &str) -> Expr {
        Expr::Match {
            subject: Box::new(ident("x")),
            arms: vec![
                MatchArm {
                    pattern: Pattern::Constructor(ctor.to_string(), vec!["v".to_string()]),
                    body: Box::new(ident("v")),
                },
                MatchArm {
                    pattern: Pattern::Wildcard,
                    body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
                },
            ],
        }
    }

    #[test]
    fn resolves_param_to_slot() {
        let body = FnBody::from_expr(Spanned::bare(Expr::BinOp(
            BinOp::Add,
            Box::new(ident("a")),
            Box::new(ident("b")),
        )));
        let mut fd = int_fn("add", &["a", "b"], body);
        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::BinOp(_, left, right),
                ..
            }) => {
                assert_eq!(left.node, resolved(0, "a"));
                assert_eq!(right.node, resolved(1, "b"));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        let body = FnBody::from_expr(Spanned::bare(Expr::FnCall(
            Box::new(ident("Console")),
            vec![ident("x")],
        )));
        let mut fd = int_fn("f", &["x"], body);
        resolve_fn(&mut fd);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::FnCall(func, args),
                ..
            }) => {
                // `Console` is not a local, so it must stay an identifier.
                assert_eq!(func.node, Expr::Ident("Console".to_string()));
                assert_eq!(args[0].node, resolved(0, "x"));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        let y_init = Spanned::bare(Expr::BinOp(
            BinOp::Add,
            Box::new(ident("x")),
            Box::new(Spanned::bare(Expr::Literal(Literal::Int(1)))),
        ));
        let body = FnBody::Block(vec![
            Stmt::Binding("y".to_string(), None, y_init),
            Stmt::Expr(ident("y")),
        ]);
        let mut fd = int_fn("f", &["x"], body);
        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::BinOp(_, left, _),
                    ..
                },
            ) => assert_eq!(left.node, resolved(0, "x")),
            other => panic!("unexpected stmt: {:?}", other),
        }
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved { slot: 1, .. },
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        let body = FnBody::from_expr(Spanned::new(match_x_binding_v("Result.Ok"), 1));
        let mut fd = int_fn("f", &["x"], body);
        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => assert_eq!(arms[0].body.node, resolved(1, "v")),
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings_inside_binding_initializer() {
        let body = FnBody::Block(vec![
            Stmt::Binding(
                "result".to_string(),
                None,
                Spanned::bare(match_x_binding_v("Option.Some")),
            ),
            Stmt::Expr(ident("result")),
        ]);
        let mut fd = int_fn("f", &["x"], body);
        resolve_fn(&mut fd);

        // The binding name is claimed before the pattern names inside its
        // initialiser, so `result` gets slot 1 and `v` gets slot 2.
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["result"], 1);
        assert_eq!(res.local_slots["v"], 2);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::Match { arms, .. },
                    ..
                },
            ) => assert_eq!(arms[0].body.node, resolved(2, "v")),
            other => panic!("unexpected stmt: {:?}", other),
        }
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved { slot: 1, .. },
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }
}