1use std::collections::HashMap;
13use std::rc::Rc;
14
15use crate::ast::*;
16
17pub fn resolve_program(items: &mut Vec<TopLevel>) {
19 for item in items.iter_mut() {
20 if let TopLevel::FnDef(fd) = item {
21 resolve_fn(fd);
22 }
23 }
24}
25
26fn resolve_fn(fd: &mut FnDef) {
28 let mut local_slots: HashMap<String, u16> = HashMap::new();
29 let mut next_slot: u16 = 0;
30
31 for (param_name, _) in &fd.params {
33 local_slots.insert(param_name.clone(), next_slot);
34 next_slot += 1;
35 }
36
37 match fd.body.as_ref() {
39 FnBody::Expr(expr) => {
40 collect_expr_bindings(expr, &mut local_slots, &mut next_slot);
41 }
42 FnBody::Block(stmts) => {
43 collect_binding_slots(stmts, &mut local_slots, &mut next_slot);
44 }
45 }
46
47 let mut body = fd.body.as_ref().clone();
49 match &mut body {
50 FnBody::Expr(expr) => {
51 resolve_expr(expr, &local_slots);
52 }
53 FnBody::Block(stmts) => {
54 resolve_stmts(stmts, &local_slots);
55 }
56 }
57 fd.body = Rc::new(body);
58
59 fd.resolution = Some(FnResolution {
60 local_count: next_slot,
61 local_slots,
62 });
63}
64
65fn collect_binding_slots(
68 stmts: &[Stmt],
69 local_slots: &mut HashMap<String, u16>,
70 next_slot: &mut u16,
71) {
72 for stmt in stmts {
73 match stmt {
74 Stmt::Binding(name, _, _) => {
75 if !local_slots.contains_key(name) {
76 local_slots.insert(name.clone(), *next_slot);
77 *next_slot += 1;
78 }
79 }
80 Stmt::Expr(expr) => {
81 collect_expr_bindings(expr, local_slots, next_slot);
82 }
83 }
84 }
85}
86
87fn collect_expr_bindings(expr: &Expr, local_slots: &mut HashMap<String, u16>, next_slot: &mut u16) {
89 match expr {
90 Expr::Match { subject, arms, .. } => {
91 collect_expr_bindings(subject, local_slots, next_slot);
92 for arm in arms {
93 collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
94 collect_expr_bindings(&arm.body, local_slots, next_slot);
95 }
96 }
97 Expr::BinOp(_, left, right) => {
98 collect_expr_bindings(left, local_slots, next_slot);
99 collect_expr_bindings(right, local_slots, next_slot);
100 }
101 Expr::FnCall(func, args) => {
102 collect_expr_bindings(func, local_slots, next_slot);
103 for arg in args {
104 collect_expr_bindings(arg, local_slots, next_slot);
105 }
106 }
107 Expr::Pipe(left, right) => {
108 collect_expr_bindings(left, local_slots, next_slot);
109 collect_expr_bindings(right, local_slots, next_slot);
110 }
111 Expr::ErrorProp(inner) => {
112 collect_expr_bindings(inner, local_slots, next_slot);
113 }
114 Expr::Constructor(_, Some(inner)) => {
115 collect_expr_bindings(inner, local_slots, next_slot);
116 }
117 Expr::List(elements) => {
118 for elem in elements {
119 collect_expr_bindings(elem, local_slots, next_slot);
120 }
121 }
122 Expr::Tuple(items) => {
123 for item in items {
124 collect_expr_bindings(item, local_slots, next_slot);
125 }
126 }
127 Expr::MapLiteral(entries) => {
128 for (key, value) in entries {
129 collect_expr_bindings(key, local_slots, next_slot);
130 collect_expr_bindings(value, local_slots, next_slot);
131 }
132 }
133 Expr::InterpolatedStr(parts) => {
134 for part in parts {
135 if let StrPart::Parsed(e) = part {
136 collect_expr_bindings(e, local_slots, next_slot);
137 }
138 }
139 }
140 Expr::RecordCreate { fields, .. } => {
141 for (_, expr) in fields {
142 collect_expr_bindings(expr, local_slots, next_slot);
143 }
144 }
145 Expr::RecordUpdate { base, updates, .. } => {
146 collect_expr_bindings(base, local_slots, next_slot);
147 for (_, expr) in updates {
148 collect_expr_bindings(expr, local_slots, next_slot);
149 }
150 }
151 Expr::Attr(obj, _) => {
152 collect_expr_bindings(obj, local_slots, next_slot);
153 }
154 Expr::TailCall(boxed) => {
155 for arg in &boxed.1 {
156 collect_expr_bindings(arg, local_slots, next_slot);
157 }
158 }
159 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
161 }
162}
163
164fn collect_pattern_bindings(
166 pattern: &Pattern,
167 local_slots: &mut HashMap<String, u16>,
168 next_slot: &mut u16,
169) {
170 match pattern {
171 Pattern::Ident(name) => {
172 if !local_slots.contains_key(name) {
173 local_slots.insert(name.clone(), *next_slot);
174 *next_slot += 1;
175 }
176 }
177 Pattern::Cons(head, tail) => {
178 for name in [head, tail] {
179 if name != "_" && !local_slots.contains_key(name) {
180 local_slots.insert(name.clone(), *next_slot);
181 *next_slot += 1;
182 }
183 }
184 }
185 Pattern::Constructor(_, bindings) => {
186 for name in bindings {
187 if name != "_" && !local_slots.contains_key(name) {
188 local_slots.insert(name.clone(), *next_slot);
189 *next_slot += 1;
190 }
191 }
192 }
193 Pattern::Tuple(items) => {
194 for item in items {
195 collect_pattern_bindings(item, local_slots, next_slot);
196 }
197 }
198 Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
199 }
200}
201
202fn resolve_expr(expr: &mut Expr, local_slots: &HashMap<String, u16>) {
204 match expr {
205 Expr::Ident(name) => {
206 if let Some(&slot) = local_slots.get(name) {
207 *expr = Expr::Resolved(slot);
208 }
209 }
211 Expr::Resolved(_) | Expr::Literal(_) => {}
212 Expr::Attr(obj, _) => {
213 resolve_expr(obj, local_slots);
214 }
215 Expr::FnCall(func, args) => {
216 resolve_expr(func, local_slots);
217 for arg in args {
218 resolve_expr(arg, local_slots);
219 }
220 }
221 Expr::BinOp(_, left, right) => {
222 resolve_expr(left, local_slots);
223 resolve_expr(right, local_slots);
224 }
225 Expr::Match { subject, arms, .. } => {
226 resolve_expr(subject, local_slots);
227 for arm in arms {
228 resolve_expr(&mut arm.body, local_slots);
229 }
230 }
231 Expr::Pipe(left, right) => {
232 resolve_expr(left, local_slots);
233 resolve_expr(right, local_slots);
234 }
235 Expr::Constructor(_, Some(inner)) => {
236 resolve_expr(inner, local_slots);
237 }
238 Expr::Constructor(_, None) => {}
239 Expr::ErrorProp(inner) => {
240 resolve_expr(inner, local_slots);
241 }
242 Expr::InterpolatedStr(parts) => {
243 for part in parts {
244 if let StrPart::Parsed(e) = part {
245 resolve_expr(e, local_slots);
246 }
247 }
248 }
249 Expr::List(elements) => {
250 for elem in elements {
251 resolve_expr(elem, local_slots);
252 }
253 }
254 Expr::Tuple(items) => {
255 for item in items {
256 resolve_expr(item, local_slots);
257 }
258 }
259 Expr::MapLiteral(entries) => {
260 for (key, value) in entries {
261 resolve_expr(key, local_slots);
262 resolve_expr(value, local_slots);
263 }
264 }
265 Expr::RecordCreate { fields, .. } => {
266 for (_, expr) in fields {
267 resolve_expr(expr, local_slots);
268 }
269 }
270 Expr::RecordUpdate { base, updates, .. } => {
271 resolve_expr(base, local_slots);
272 for (_, expr) in updates {
273 resolve_expr(expr, local_slots);
274 }
275 }
276 Expr::TailCall(boxed) => {
277 for arg in &mut boxed.1 {
278 resolve_expr(arg, local_slots);
279 }
280 }
281 }
282}
283
284fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
286 for stmt in stmts {
287 match stmt {
288 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
289 resolve_expr(expr, local_slots);
290 }
291 }
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `FnDef` named `f` with `Int`-typed `params` around `body`.
    fn fn_def(params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: "f".to_string(),
            line: 1,
            params: params
                .iter()
                .map(|&p| (p.to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    #[test]
    fn resolves_param_to_slot() {
        let mut fd = FnDef {
            name: "add".to_string(),
            line: 1,
            params: vec![
                ("a".to_string(), "Int".to_string()),
                ("b".to_string(), "Int".to_string()),
            ],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::Expr(Expr::BinOp(
                BinOp::Add,
                Box::new(Expr::Ident("a".to_string())),
                Box::new(Expr::Ident("b".to_string())),
            ))),
            resolution: None,
        };
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.as_ref() {
            FnBody::Expr(Expr::BinOp(_, left, right)) => {
                assert_eq!(**left, Expr::Resolved(0));
                assert_eq!(**right, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("x".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::Expr(Expr::FnCall(
                Box::new(Expr::Ident("Console".to_string())),
                vec![Expr::Ident("x".to_string())],
            ))),
            resolution: None,
        };
        resolve_fn(&mut fd);
        match fd.body.as_ref() {
            FnBody::Expr(Expr::FnCall(func, args)) => {
                assert_eq!(**func, Expr::Ident("Console".to_string()));
                assert_eq!(args[0], Expr::Resolved(0));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("x".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::Block(vec![
                Stmt::Binding(
                    "y".to_string(),
                    None,
                    Expr::BinOp(
                        BinOp::Add,
                        Box::new(Expr::Ident("x".to_string())),
                        Box::new(Expr::Literal(Literal::Int(1))),
                    ),
                ),
                Stmt::Expr(Expr::Ident("y".to_string())),
            ])),
            resolution: None,
        };
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.as_ref() {
            FnBody::Block(stmts) => {
                match &stmts[0] {
                    Stmt::Binding(_, _, Expr::BinOp(_, left, _)) => {
                        assert_eq!(**left, Expr::Resolved(0));
                    }
                    other => panic!("unexpected stmt: {:?}", other),
                }
                match &stmts[1] {
                    Stmt::Expr(Expr::Resolved(1)) => {}
                    other => panic!("unexpected stmt: {:?}", other),
                }
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        let mut fd = FnDef {
            name: "f".to_string(),
            line: 1,
            params: vec![("x".to_string(), "Int".to_string())],
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(FnBody::Expr(Expr::Match {
                subject: Box::new(Expr::Ident("x".to_string())),
                arms: vec![
                    MatchArm {
                        pattern: Pattern::Constructor(
                            "Result.Ok".to_string(),
                            vec!["v".to_string()],
                        ),
                        body: Box::new(Expr::Ident("v".to_string())),
                    },
                    MatchArm {
                        pattern: Pattern::Wildcard,
                        body: Box::new(Expr::Literal(Literal::Int(0))),
                    },
                ],
                line: 1,
            })),
            resolution: None,
        };
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.as_ref() {
            FnBody::Expr(Expr::Match { arms, .. }) => {
                assert_eq!(*arms[0].body, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    // `_` in a cons pattern must not get a slot, and a name bound in two arms
    // must share a single slot.
    #[test]
    fn skips_underscore_and_reuses_slot_for_repeated_pattern_name() {
        let mut fd = fn_def(
            &["xs"],
            FnBody::Expr(Expr::Match {
                subject: Box::new(Expr::Ident("xs".to_string())),
                arms: vec![
                    MatchArm {
                        pattern: Pattern::Cons("h".to_string(), "_".to_string()),
                        body: Box::new(Expr::Ident("h".to_string())),
                    },
                    MatchArm {
                        pattern: Pattern::Constructor(
                            "Option.Some".to_string(),
                            vec!["h".to_string()],
                        ),
                        body: Box::new(Expr::Ident("h".to_string())),
                    },
                ],
                line: 1,
            }),
        );
        resolve_fn(&mut fd);
        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["h"], 1);
        assert!(!res.local_slots.contains_key("_"));
        assert_eq!(res.local_count, 2);

        match fd.body.as_ref() {
            FnBody::Expr(Expr::Match { arms, .. }) => {
                assert_eq!(*arms[0].body, Expr::Resolved(1));
                assert_eq!(*arms[1].body, Expr::Resolved(1));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    // Resolution rewrites `fd`'s own body; another `Rc` sharing the original body
    // must still see the unresolved tree.
    #[test]
    fn leaves_other_holders_of_shared_body_untouched() {
        let shared = Rc::new(FnBody::Expr(Expr::Ident("x".to_string())));
        let mut fd = fn_def(&["x"], FnBody::Expr(Expr::Literal(Literal::Int(0))));
        fd.body = Rc::clone(&shared);
        resolve_fn(&mut fd);

        match fd.body.as_ref() {
            FnBody::Expr(expr) => assert_eq!(*expr, Expr::Resolved(0)),
            other => panic!("unexpected body: {:?}", other),
        }
        match shared.as_ref() {
            FnBody::Expr(expr) => assert_eq!(*expr, Expr::Ident("x".to_string())),
            other => panic!("unexpected body: {:?}", other),
        }
    }
}