1use std::collections::HashMap;
13use std::sync::Arc as Rc;
14
15use crate::ast::*;
16
17pub fn resolve_program(items: &mut [TopLevel]) {
22 for item in items.iter_mut() {
23 if let TopLevel::FnDef(fd) = item {
24 resolve_fn(fd);
25 }
26 }
27}
28
29fn resolve_fn(fd: &mut FnDef) {
31 let mut local_slots: HashMap<String, u16> = HashMap::new();
32 let mut next_slot: u16 = 0;
33
34 for (param_name, _) in &fd.params {
36 local_slots.insert(param_name.clone(), next_slot);
37 next_slot += 1;
38 }
39
40 collect_binding_slots(fd.body.stmts(), &mut local_slots, &mut next_slot);
42
43 let mut body = fd.body.as_ref().clone();
45 resolve_stmts(body.stmts_mut(), &local_slots);
46 fd.body = Rc::new(body);
47
48 fd.resolution = Some(FnResolution {
49 local_count: next_slot,
50 local_slots: Rc::new(local_slots),
51 });
52}
53
54fn collect_binding_slots(
57 stmts: &[Stmt],
58 local_slots: &mut HashMap<String, u16>,
59 next_slot: &mut u16,
60) {
61 for stmt in stmts {
62 match stmt {
63 Stmt::Binding(name, _, expr) => {
64 if !local_slots.contains_key(name) {
65 local_slots.insert(name.clone(), *next_slot);
66 *next_slot += 1;
67 }
68 collect_expr_bindings(expr, local_slots, next_slot);
69 }
70 Stmt::Expr(expr) => {
71 collect_expr_bindings(expr, local_slots, next_slot);
72 }
73 }
74 }
75}
76
77fn collect_expr_bindings(
79 expr: &Spanned<Expr>,
80 local_slots: &mut HashMap<String, u16>,
81 next_slot: &mut u16,
82) {
83 match &expr.node {
84 Expr::Match { subject, arms } => {
85 collect_expr_bindings(subject, local_slots, next_slot);
86 for arm in arms {
87 collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
88 collect_expr_bindings(&arm.body, local_slots, next_slot);
89 }
90 }
91 Expr::BinOp(_, left, right) => {
92 collect_expr_bindings(left, local_slots, next_slot);
93 collect_expr_bindings(right, local_slots, next_slot);
94 }
95 Expr::FnCall(func, args) => {
96 collect_expr_bindings(func, local_slots, next_slot);
97 for arg in args {
98 collect_expr_bindings(arg, local_slots, next_slot);
99 }
100 }
101 Expr::ErrorProp(inner) => {
102 collect_expr_bindings(inner, local_slots, next_slot);
103 }
104 Expr::Constructor(_, Some(inner)) => {
105 collect_expr_bindings(inner, local_slots, next_slot);
106 }
107 Expr::List(elements) => {
108 for elem in elements {
109 collect_expr_bindings(elem, local_slots, next_slot);
110 }
111 }
112 Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
113 for item in items {
114 collect_expr_bindings(item, local_slots, next_slot);
115 }
116 }
117 Expr::MapLiteral(entries) => {
118 for (key, value) in entries {
119 collect_expr_bindings(key, local_slots, next_slot);
120 collect_expr_bindings(value, local_slots, next_slot);
121 }
122 }
123 Expr::InterpolatedStr(parts) => {
124 for part in parts {
125 if let StrPart::Parsed(e) = part {
126 collect_expr_bindings(e, local_slots, next_slot);
127 }
128 }
129 }
130 Expr::RecordCreate { fields, .. } => {
131 for (_, expr) in fields {
132 collect_expr_bindings(expr, local_slots, next_slot);
133 }
134 }
135 Expr::RecordUpdate { base, updates, .. } => {
136 collect_expr_bindings(base, local_slots, next_slot);
137 for (_, expr) in updates {
138 collect_expr_bindings(expr, local_slots, next_slot);
139 }
140 }
141 Expr::Attr(obj, _) => {
142 collect_expr_bindings(obj, local_slots, next_slot);
143 }
144 Expr::TailCall(boxed) => {
145 for arg in &boxed.args {
146 collect_expr_bindings(arg, local_slots, next_slot);
147 }
148 }
149 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
151 }
152}
153
154fn collect_pattern_bindings(
156 pattern: &Pattern,
157 local_slots: &mut HashMap<String, u16>,
158 next_slot: &mut u16,
159) {
160 match pattern {
161 Pattern::Ident(name) => {
162 if !local_slots.contains_key(name) {
163 local_slots.insert(name.clone(), *next_slot);
164 *next_slot += 1;
165 }
166 }
167 Pattern::Cons(head, tail) => {
168 for name in [head, tail] {
169 if name != "_" && !local_slots.contains_key(name) {
170 local_slots.insert(name.clone(), *next_slot);
171 *next_slot += 1;
172 }
173 }
174 }
175 Pattern::Constructor(_, bindings) => {
176 for name in bindings {
177 if name != "_" && !local_slots.contains_key(name) {
178 local_slots.insert(name.clone(), *next_slot);
179 *next_slot += 1;
180 }
181 }
182 }
183 Pattern::Tuple(items) => {
184 for item in items {
185 collect_pattern_bindings(item, local_slots, next_slot);
186 }
187 }
188 Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
189 }
190}
191
192fn resolve_expr(expr: &mut Spanned<Expr>, local_slots: &HashMap<String, u16>) {
194 match &mut expr.node {
195 Expr::Ident(name) => {
196 if let Some(&slot) = local_slots.get(name) {
197 expr.node = Expr::Resolved {
198 slot,
199 name: name.clone(),
200 last_use: AnnotBool(false),
201 };
202 }
203 }
205 Expr::Resolved { .. } | Expr::Literal(_) => {}
206 Expr::Attr(obj, _) => {
207 resolve_expr(obj, local_slots);
208 }
209 Expr::FnCall(func, args) => {
210 resolve_expr(func, local_slots);
211 for arg in args {
212 resolve_expr(arg, local_slots);
213 }
214 }
215 Expr::BinOp(_, left, right) => {
216 resolve_expr(left, local_slots);
217 resolve_expr(right, local_slots);
218 }
219 Expr::Match { subject, arms } => {
220 resolve_expr(subject, local_slots);
221 for arm in arms {
222 resolve_expr(&mut arm.body, local_slots);
223 }
224 }
225 Expr::Constructor(_, Some(inner)) => {
226 resolve_expr(inner, local_slots);
227 }
228 Expr::Constructor(_, None) => {}
229 Expr::ErrorProp(inner) => {
230 resolve_expr(inner, local_slots);
231 }
232 Expr::InterpolatedStr(parts) => {
233 for part in parts {
234 if let StrPart::Parsed(e) = part {
235 resolve_expr(e, local_slots);
236 }
237 }
238 }
239 Expr::List(elements) => {
240 for elem in elements {
241 resolve_expr(elem, local_slots);
242 }
243 }
244 Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
245 for item in items {
246 resolve_expr(item, local_slots);
247 }
248 }
249 Expr::MapLiteral(entries) => {
250 for (key, value) in entries {
251 resolve_expr(key, local_slots);
252 resolve_expr(value, local_slots);
253 }
254 }
255 Expr::RecordCreate { fields, .. } => {
256 for (_, expr) in fields {
257 resolve_expr(expr, local_slots);
258 }
259 }
260 Expr::RecordUpdate { base, updates, .. } => {
261 resolve_expr(base, local_slots);
262 for (_, expr) in updates {
263 resolve_expr(expr, local_slots);
264 }
265 }
266 Expr::TailCall(boxed) => {
267 for arg in &mut boxed.args {
268 resolve_expr(arg, local_slots);
269 }
270 }
271 }
272}
273
274fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
276 for stmt in stmts {
277 match stmt {
278 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
279 resolve_expr(expr, local_slots);
280 }
281 }
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// A span-less identifier expression.
    fn ident(name: &str) -> Spanned<Expr> {
        Spanned::bare(Expr::Ident(name.to_string()))
    }

    /// A span-less integer literal expression.
    fn int_lit(value: i64) -> Spanned<Expr> {
        Spanned::bare(Expr::Literal(Literal::Int(value)))
    }

    /// The node the resolver is expected to produce for local `name` in `slot`.
    fn resolved(slot: u16, name: &str) -> Expr {
        Expr::Resolved {
            slot,
            name: name.to_string(),
            last_use: AnnotBool(false),
        }
    }

    /// An unresolved function returning `Int` whose parameters are all `Int`.
    fn int_fn(name: &str, params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: params
                .iter()
                .map(|param| (param.to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    /// Match arms `ctor(v) => v` and `_ => 0`.
    fn unwrap_or_zero_arms(ctor: &str) -> Vec<MatchArm> {
        vec![
            MatchArm {
                pattern: Pattern::Constructor(ctor.to_string(), vec!["v".to_string()]),
                body: Box::new(ident("v")),
            },
            MatchArm {
                pattern: Pattern::Wildcard,
                body: Box::new(int_lit(0)),
            },
        ]
    }

    #[test]
    fn resolves_param_to_slot() {
        let body = FnBody::from_expr(Spanned::bare(Expr::BinOp(
            BinOp::Add,
            Box::new(ident("a")),
            Box::new(ident("b")),
        )));
        let mut fd = int_fn("add", &["a", "b"], body);
        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::BinOp(_, left, right),
                ..
            }) => {
                assert_eq!(left.node, resolved(0, "a"));
                assert_eq!(right.node, resolved(1, "b"));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        let body = FnBody::from_expr(Spanned::bare(Expr::FnCall(
            Box::new(ident("Console")),
            vec![ident("x")],
        )));
        let mut fd = int_fn("f", &["x"], body);
        resolve_fn(&mut fd);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::FnCall(func, args),
                ..
            }) => {
                assert_eq!(func.node, Expr::Ident("Console".to_string()));
                assert_eq!(args[0].node, resolved(0, "x"));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        let body = FnBody::Block(vec![
            Stmt::Binding(
                "y".to_string(),
                None,
                Spanned::bare(Expr::BinOp(
                    BinOp::Add,
                    Box::new(ident("x")),
                    Box::new(int_lit(1)),
                )),
            ),
            Stmt::Expr(ident("y")),
        ]);
        let mut fd = int_fn("f", &["x"], body);
        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::BinOp(_, left, _),
                    ..
                },
            ) => assert_eq!(left.node, resolved(0, "x")),
            other => panic!("unexpected stmt: {:?}", other),
        }
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved { slot: 1, .. },
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        let body = FnBody::from_expr(Spanned::new(
            Expr::Match {
                subject: Box::new(ident("x")),
                arms: unwrap_or_zero_arms("Result.Ok"),
            },
            1,
        ));
        let mut fd = int_fn("f", &["x"], body);
        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => assert_eq!(arms[0].body.node, resolved(1, "v")),
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings_inside_binding_initializer() {
        let body = FnBody::Block(vec![
            Stmt::Binding(
                "result".to_string(),
                None,
                Spanned::bare(Expr::Match {
                    subject: Box::new(ident("x")),
                    arms: unwrap_or_zero_arms("Option.Some"),
                }),
            ),
            Stmt::Expr(ident("result")),
        ]);
        let mut fd = int_fn("f", &["x"], body);
        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["result"], 1);
        assert_eq!(res.local_slots["v"], 2);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::Match { arms, .. },
                    ..
                },
            ) => assert_eq!(arms[0].body.node, resolved(2, "v")),
            other => panic!("unexpected stmt: {:?}", other),
        }
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved { slot: 1, .. },
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }
}