1use std::collections::HashMap;
13use std::sync::Arc as Rc;
14
15use crate::ast::*;
16
17pub fn resolve_program(items: &mut [TopLevel]) {
19 for item in items.iter_mut() {
20 if let TopLevel::FnDef(fd) = item {
21 resolve_fn(fd);
22 }
23 }
24}
25
26fn resolve_fn(fd: &mut FnDef) {
28 let mut local_slots: HashMap<String, u16> = HashMap::new();
29 let mut next_slot: u16 = 0;
30
31 for (param_name, _) in &fd.params {
33 local_slots.insert(param_name.clone(), next_slot);
34 next_slot += 1;
35 }
36
37 collect_binding_slots(fd.body.stmts(), &mut local_slots, &mut next_slot);
39
40 let mut body = fd.body.as_ref().clone();
42 resolve_stmts(body.stmts_mut(), &local_slots);
43 fd.body = Rc::new(body);
44
45 fd.resolution = Some(FnResolution {
46 local_count: next_slot,
47 local_slots: Rc::new(local_slots),
48 });
49}
50
51fn collect_binding_slots(
54 stmts: &[Stmt],
55 local_slots: &mut HashMap<String, u16>,
56 next_slot: &mut u16,
57) {
58 for stmt in stmts {
59 match stmt {
60 Stmt::Binding(name, _, _) => {
61 if !local_slots.contains_key(name) {
62 local_slots.insert(name.clone(), *next_slot);
63 *next_slot += 1;
64 }
65 }
66 Stmt::Expr(expr) => {
67 collect_expr_bindings(expr, local_slots, next_slot);
68 }
69 }
70 }
71}
72
73fn collect_expr_bindings(
75 expr: &Spanned<Expr>,
76 local_slots: &mut HashMap<String, u16>,
77 next_slot: &mut u16,
78) {
79 match &expr.node {
80 Expr::Match { subject, arms } => {
81 collect_expr_bindings(subject, local_slots, next_slot);
82 for arm in arms {
83 collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
84 collect_expr_bindings(&arm.body, local_slots, next_slot);
85 }
86 }
87 Expr::BinOp(_, left, right) => {
88 collect_expr_bindings(left, local_slots, next_slot);
89 collect_expr_bindings(right, local_slots, next_slot);
90 }
91 Expr::FnCall(func, args) => {
92 collect_expr_bindings(func, local_slots, next_slot);
93 for arg in args {
94 collect_expr_bindings(arg, local_slots, next_slot);
95 }
96 }
97 Expr::ErrorProp(inner) => {
98 collect_expr_bindings(inner, local_slots, next_slot);
99 }
100 Expr::Constructor(_, Some(inner)) => {
101 collect_expr_bindings(inner, local_slots, next_slot);
102 }
103 Expr::List(elements) => {
104 for elem in elements {
105 collect_expr_bindings(elem, local_slots, next_slot);
106 }
107 }
108 Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
109 for item in items {
110 collect_expr_bindings(item, local_slots, next_slot);
111 }
112 }
113 Expr::MapLiteral(entries) => {
114 for (key, value) in entries {
115 collect_expr_bindings(key, local_slots, next_slot);
116 collect_expr_bindings(value, local_slots, next_slot);
117 }
118 }
119 Expr::InterpolatedStr(parts) => {
120 for part in parts {
121 if let StrPart::Parsed(e) = part {
122 collect_expr_bindings(e, local_slots, next_slot);
123 }
124 }
125 }
126 Expr::RecordCreate { fields, .. } => {
127 for (_, expr) in fields {
128 collect_expr_bindings(expr, local_slots, next_slot);
129 }
130 }
131 Expr::RecordUpdate { base, updates, .. } => {
132 collect_expr_bindings(base, local_slots, next_slot);
133 for (_, expr) in updates {
134 collect_expr_bindings(expr, local_slots, next_slot);
135 }
136 }
137 Expr::Attr(obj, _) => {
138 collect_expr_bindings(obj, local_slots, next_slot);
139 }
140 Expr::TailCall(boxed) => {
141 for arg in &boxed.1 {
142 collect_expr_bindings(arg, local_slots, next_slot);
143 }
144 }
145 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
147 }
148}
149
150fn collect_pattern_bindings(
152 pattern: &Pattern,
153 local_slots: &mut HashMap<String, u16>,
154 next_slot: &mut u16,
155) {
156 match pattern {
157 Pattern::Ident(name) => {
158 if !local_slots.contains_key(name) {
159 local_slots.insert(name.clone(), *next_slot);
160 *next_slot += 1;
161 }
162 }
163 Pattern::Cons(head, tail) => {
164 for name in [head, tail] {
165 if name != "_" && !local_slots.contains_key(name) {
166 local_slots.insert(name.clone(), *next_slot);
167 *next_slot += 1;
168 }
169 }
170 }
171 Pattern::Constructor(_, bindings) => {
172 for name in bindings {
173 if name != "_" && !local_slots.contains_key(name) {
174 local_slots.insert(name.clone(), *next_slot);
175 *next_slot += 1;
176 }
177 }
178 }
179 Pattern::Tuple(items) => {
180 for item in items {
181 collect_pattern_bindings(item, local_slots, next_slot);
182 }
183 }
184 Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
185 }
186}
187
188fn resolve_expr(expr: &mut Spanned<Expr>, local_slots: &HashMap<String, u16>) {
190 match &mut expr.node {
191 Expr::Ident(name) => {
192 if let Some(&slot) = local_slots.get(name) {
193 expr.node = Expr::Resolved(slot);
194 }
195 }
197 Expr::Resolved(_) | Expr::Literal(_) => {}
198 Expr::Attr(obj, _) => {
199 resolve_expr(obj, local_slots);
200 }
201 Expr::FnCall(func, args) => {
202 resolve_expr(func, local_slots);
203 for arg in args {
204 resolve_expr(arg, local_slots);
205 }
206 }
207 Expr::BinOp(_, left, right) => {
208 resolve_expr(left, local_slots);
209 resolve_expr(right, local_slots);
210 }
211 Expr::Match { subject, arms } => {
212 resolve_expr(subject, local_slots);
213 for arm in arms {
214 resolve_expr(&mut arm.body, local_slots);
215 }
216 }
217 Expr::Constructor(_, Some(inner)) => {
218 resolve_expr(inner, local_slots);
219 }
220 Expr::Constructor(_, None) => {}
221 Expr::ErrorProp(inner) => {
222 resolve_expr(inner, local_slots);
223 }
224 Expr::InterpolatedStr(parts) => {
225 for part in parts {
226 if let StrPart::Parsed(e) = part {
227 resolve_expr(e, local_slots);
228 }
229 }
230 }
231 Expr::List(elements) => {
232 for elem in elements {
233 resolve_expr(elem, local_slots);
234 }
235 }
236 Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
237 for item in items {
238 resolve_expr(item, local_slots);
239 }
240 }
241 Expr::MapLiteral(entries) => {
242 for (key, value) in entries {
243 resolve_expr(key, local_slots);
244 resolve_expr(value, local_slots);
245 }
246 }
247 Expr::RecordCreate { fields, .. } => {
248 for (_, expr) in fields {
249 resolve_expr(expr, local_slots);
250 }
251 }
252 Expr::RecordUpdate { base, updates, .. } => {
253 resolve_expr(base, local_slots);
254 for (_, expr) in updates {
255 resolve_expr(expr, local_slots);
256 }
257 }
258 Expr::TailCall(boxed) => {
259 for arg in &mut boxed.1 {
260 resolve_expr(arg, local_slots);
261 }
262 }
263 }
264}
265
266fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
268 for stmt in stmts {
269 match stmt {
270 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
271 resolve_expr(expr, local_slots);
272 }
273 }
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `name` as an unresolved identifier expression.
    fn ident(name: &str) -> Spanned<Expr> {
        Spanned::bare(Expr::Ident(name.to_string()))
    }

    /// Builds an unresolved function definition on line 1 whose parameters and
    /// return type are all `Int`, with no effects and no description.
    fn int_fn(name: &str, params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: params
                .iter()
                .map(|&param| (param.to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    #[test]
    fn resolves_param_to_slot() {
        let sum = Expr::BinOp(BinOp::Add, Box::new(ident("a")), Box::new(ident("b")));
        let mut fd = int_fn("add", &["a", "b"], FnBody::from_expr(Spanned::bare(sum)));

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::BinOp(_, left, right),
                ..
            }) => {
                assert_eq!(left.node, Expr::Resolved(0));
                assert_eq!(right.node, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        // `Console` is not a local of `f`, so it must stay an identifier.
        let call = Expr::FnCall(Box::new(ident("Console")), vec![ident("x")]);
        let mut fd = int_fn("f", &["x"], FnBody::from_expr(Spanned::bare(call)));

        resolve_fn(&mut fd);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::FnCall(func, args),
                ..
            }) => {
                assert_eq!(func.node, Expr::Ident("Console".to_string()));
                assert_eq!(args[0].node, Expr::Resolved(0));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        let x_plus_one = Spanned::bare(Expr::BinOp(
            BinOp::Add,
            Box::new(ident("x")),
            Box::new(Spanned::bare(Expr::Literal(Literal::Int(1)))),
        ));
        let block = FnBody::Block(vec![
            Stmt::Binding("y".to_string(), None, x_plus_one),
            Stmt::Expr(ident("y")),
        ]);
        let mut fd = int_fn("f", &["x"], block);

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        // The initializer of `y` reads parameter `x` from slot 0.
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::BinOp(_, left, _),
                    ..
                },
            ) => {
                assert_eq!(left.node, Expr::Resolved(0));
            }
            other => panic!("unexpected stmt: {:?}", other),
        }
        // The trailing `y` reads the binding from slot 1.
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved(1),
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        let ok_arm = MatchArm {
            pattern: Pattern::Constructor("Result.Ok".to_string(), vec!["v".to_string()]),
            body: Box::new(ident("v")),
        };
        let fallback_arm = MatchArm {
            pattern: Pattern::Wildcard,
            body: Box::new(Spanned::bare(Expr::Literal(Literal::Int(0)))),
        };
        let match_expr = Spanned::new(
            Expr::Match {
                subject: Box::new(ident("x")),
                arms: vec![ok_arm, fallback_arm],
            },
            1,
        );
        let mut fd = int_fn("f", &["x"], FnBody::from_expr(match_expr));

        resolve_fn(&mut fd);

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => {
                assert_eq!(arms[0].body.node, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }
}