1use std::collections::HashMap;
13use std::rc::Rc;
14
15use crate::ast::*;
16
17pub fn resolve_program(items: &mut [TopLevel]) {
19 for item in items.iter_mut() {
20 if let TopLevel::FnDef(fd) = item {
21 resolve_fn(fd);
22 }
23 }
24}
25
26fn resolve_fn(fd: &mut FnDef) {
28 let mut local_slots: HashMap<String, u16> = HashMap::new();
29 let mut next_slot: u16 = 0;
30
31 for (param_name, _) in &fd.params {
33 local_slots.insert(param_name.clone(), next_slot);
34 next_slot += 1;
35 }
36
37 match fd.body.as_ref() {
39 FnBody::Expr(expr) => {
40 collect_expr_bindings(expr, &mut local_slots, &mut next_slot);
41 }
42 FnBody::Block(stmts) => {
43 collect_binding_slots(stmts, &mut local_slots, &mut next_slot);
44 }
45 }
46
47 let mut body = fd.body.as_ref().clone();
49 match &mut body {
50 FnBody::Expr(expr) => {
51 resolve_expr(expr, &local_slots);
52 }
53 FnBody::Block(stmts) => {
54 resolve_stmts(stmts, &local_slots);
55 }
56 }
57 fd.body = Rc::new(body);
58
59 fd.resolution = Some(FnResolution {
60 local_count: next_slot,
61 local_slots,
62 });
63}
64
65fn collect_binding_slots(
68 stmts: &[Stmt],
69 local_slots: &mut HashMap<String, u16>,
70 next_slot: &mut u16,
71) {
72 for stmt in stmts {
73 match stmt {
74 Stmt::Binding(name, _, _) => {
75 if !local_slots.contains_key(name) {
76 local_slots.insert(name.clone(), *next_slot);
77 *next_slot += 1;
78 }
79 }
80 Stmt::Expr(expr) => {
81 collect_expr_bindings(expr, local_slots, next_slot);
82 }
83 }
84 }
85}
86
87fn collect_expr_bindings(expr: &Expr, local_slots: &mut HashMap<String, u16>, next_slot: &mut u16) {
89 match expr {
90 Expr::Match { subject, arms, .. } => {
91 collect_expr_bindings(subject, local_slots, next_slot);
92 for arm in arms {
93 collect_pattern_bindings(&arm.pattern, local_slots, next_slot);
94 collect_expr_bindings(&arm.body, local_slots, next_slot);
95 }
96 }
97 Expr::BinOp(_, left, right) => {
98 collect_expr_bindings(left, local_slots, next_slot);
99 collect_expr_bindings(right, local_slots, next_slot);
100 }
101 Expr::FnCall(func, args) => {
102 collect_expr_bindings(func, local_slots, next_slot);
103 for arg in args {
104 collect_expr_bindings(arg, local_slots, next_slot);
105 }
106 }
107 Expr::ErrorProp(inner) => {
108 collect_expr_bindings(inner, local_slots, next_slot);
109 }
110 Expr::Constructor(_, Some(inner)) => {
111 collect_expr_bindings(inner, local_slots, next_slot);
112 }
113 Expr::List(elements) => {
114 for elem in elements {
115 collect_expr_bindings(elem, local_slots, next_slot);
116 }
117 }
118 Expr::Tuple(items) => {
119 for item in items {
120 collect_expr_bindings(item, local_slots, next_slot);
121 }
122 }
123 Expr::MapLiteral(entries) => {
124 for (key, value) in entries {
125 collect_expr_bindings(key, local_slots, next_slot);
126 collect_expr_bindings(value, local_slots, next_slot);
127 }
128 }
129 Expr::InterpolatedStr(parts) => {
130 for part in parts {
131 if let StrPart::Parsed(e) = part {
132 collect_expr_bindings(e, local_slots, next_slot);
133 }
134 }
135 }
136 Expr::RecordCreate { fields, .. } => {
137 for (_, expr) in fields {
138 collect_expr_bindings(expr, local_slots, next_slot);
139 }
140 }
141 Expr::RecordUpdate { base, updates, .. } => {
142 collect_expr_bindings(base, local_slots, next_slot);
143 for (_, expr) in updates {
144 collect_expr_bindings(expr, local_slots, next_slot);
145 }
146 }
147 Expr::Attr(obj, _) => {
148 collect_expr_bindings(obj, local_slots, next_slot);
149 }
150 Expr::TailCall(boxed) => {
151 for arg in &boxed.1 {
152 collect_expr_bindings(arg, local_slots, next_slot);
153 }
154 }
155 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) | Expr::Constructor(_, None) => {}
157 }
158}
159
160fn collect_pattern_bindings(
162 pattern: &Pattern,
163 local_slots: &mut HashMap<String, u16>,
164 next_slot: &mut u16,
165) {
166 match pattern {
167 Pattern::Ident(name) => {
168 if !local_slots.contains_key(name) {
169 local_slots.insert(name.clone(), *next_slot);
170 *next_slot += 1;
171 }
172 }
173 Pattern::Cons(head, tail) => {
174 for name in [head, tail] {
175 if name != "_" && !local_slots.contains_key(name) {
176 local_slots.insert(name.clone(), *next_slot);
177 *next_slot += 1;
178 }
179 }
180 }
181 Pattern::Constructor(_, bindings) => {
182 for name in bindings {
183 if name != "_" && !local_slots.contains_key(name) {
184 local_slots.insert(name.clone(), *next_slot);
185 *next_slot += 1;
186 }
187 }
188 }
189 Pattern::Tuple(items) => {
190 for item in items {
191 collect_pattern_bindings(item, local_slots, next_slot);
192 }
193 }
194 Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => {}
195 }
196}
197
198fn resolve_expr(expr: &mut Expr, local_slots: &HashMap<String, u16>) {
200 match expr {
201 Expr::Ident(name) => {
202 if let Some(&slot) = local_slots.get(name) {
203 *expr = Expr::Resolved(slot);
204 }
205 }
207 Expr::Resolved(_) | Expr::Literal(_) => {}
208 Expr::Attr(obj, _) => {
209 resolve_expr(obj, local_slots);
210 }
211 Expr::FnCall(func, args) => {
212 resolve_expr(func, local_slots);
213 for arg in args {
214 resolve_expr(arg, local_slots);
215 }
216 }
217 Expr::BinOp(_, left, right) => {
218 resolve_expr(left, local_slots);
219 resolve_expr(right, local_slots);
220 }
221 Expr::Match { subject, arms, .. } => {
222 resolve_expr(subject, local_slots);
223 for arm in arms {
224 resolve_expr(&mut arm.body, local_slots);
225 }
226 }
227 Expr::Constructor(_, Some(inner)) => {
228 resolve_expr(inner, local_slots);
229 }
230 Expr::Constructor(_, None) => {}
231 Expr::ErrorProp(inner) => {
232 resolve_expr(inner, local_slots);
233 }
234 Expr::InterpolatedStr(parts) => {
235 for part in parts {
236 if let StrPart::Parsed(e) = part {
237 resolve_expr(e, local_slots);
238 }
239 }
240 }
241 Expr::List(elements) => {
242 for elem in elements {
243 resolve_expr(elem, local_slots);
244 }
245 }
246 Expr::Tuple(items) => {
247 for item in items {
248 resolve_expr(item, local_slots);
249 }
250 }
251 Expr::MapLiteral(entries) => {
252 for (key, value) in entries {
253 resolve_expr(key, local_slots);
254 resolve_expr(value, local_slots);
255 }
256 }
257 Expr::RecordCreate { fields, .. } => {
258 for (_, expr) in fields {
259 resolve_expr(expr, local_slots);
260 }
261 }
262 Expr::RecordUpdate { base, updates, .. } => {
263 resolve_expr(base, local_slots);
264 for (_, expr) in updates {
265 resolve_expr(expr, local_slots);
266 }
267 }
268 Expr::TailCall(boxed) => {
269 for arg in &mut boxed.1 {
270 resolve_expr(arg, local_slots);
271 }
272 }
273 }
274}
275
276fn resolve_stmts(stmts: &mut [Stmt], local_slots: &HashMap<String, u16>) {
278 for stmt in stmts {
279 match stmt {
280 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
281 resolve_expr(expr, local_slots);
282 }
283 }
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unresolved `FnDef` on line 1 whose params and return type are all `Int`.
    fn int_fn(name: &str, params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: params
                .iter()
                .map(|&param| (param.to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    /// Shorthand for an unresolved identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Ident(name.to_string())
    }

    #[test]
    fn resolves_param_to_slot() {
        let sum = Expr::BinOp(BinOp::Add, Box::new(ident("a")), Box::new(ident("b")));
        let mut fd = int_fn("add", &["a", "b"], FnBody::Expr(sum));
        resolve_fn(&mut fd);

        let resolution = fd.resolution.as_ref().unwrap();
        assert_eq!(resolution.local_slots["a"], 0);
        assert_eq!(resolution.local_slots["b"], 1);
        assert_eq!(resolution.local_count, 2);

        if let FnBody::Expr(Expr::BinOp(_, lhs, rhs)) = fd.body.as_ref() {
            assert_eq!(**lhs, Expr::Resolved(0));
            assert_eq!(**rhs, Expr::Resolved(1));
        } else {
            panic!("unexpected body: {:?}", fd.body);
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        let call = Expr::FnCall(Box::new(ident("Console")), vec![ident("x")]);
        let mut fd = int_fn("f", &["x"], FnBody::Expr(call));
        resolve_fn(&mut fd);

        if let FnBody::Expr(Expr::FnCall(callee, args)) = fd.body.as_ref() {
            // `Console` has no local slot, so it must stay an identifier.
            assert_eq!(**callee, ident("Console"));
            assert_eq!(args[0], Expr::Resolved(0));
        } else {
            panic!("unexpected body: {:?}", fd.body);
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        let x_plus_one = Expr::BinOp(
            BinOp::Add,
            Box::new(ident("x")),
            Box::new(Expr::Literal(Literal::Int(1))),
        );
        let body = FnBody::Block(vec![
            Stmt::Binding("y".to_string(), None, x_plus_one),
            Stmt::Expr(ident("y")),
        ]);
        let mut fd = int_fn("f", &["x"], body);
        resolve_fn(&mut fd);

        let resolution = fd.resolution.as_ref().unwrap();
        assert_eq!(resolution.local_slots["x"], 0);
        assert_eq!(resolution.local_slots["y"], 1);
        assert_eq!(resolution.local_count, 2);

        let stmts = match fd.body.as_ref() {
            FnBody::Block(stmts) => stmts,
            other => panic!("unexpected body: {:?}", other),
        };
        match &stmts[0] {
            Stmt::Binding(_, _, Expr::BinOp(_, lhs, _)) => assert_eq!(**lhs, Expr::Resolved(0)),
            other => panic!("unexpected stmt: {:?}", other),
        }
        assert!(
            matches!(&stmts[1], Stmt::Expr(Expr::Resolved(1))),
            "unexpected stmt: {:?}",
            stmts[1]
        );
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        let ok_arm = MatchArm {
            pattern: Pattern::Constructor("Result.Ok".to_string(), vec!["v".to_string()]),
            body: Box::new(ident("v")),
        };
        let fallback_arm = MatchArm {
            pattern: Pattern::Wildcard,
            body: Box::new(Expr::Literal(Literal::Int(0))),
        };
        let body = FnBody::Expr(Expr::Match {
            subject: Box::new(ident("x")),
            arms: vec![ok_arm, fallback_arm],
            line: 1,
        });
        let mut fd = int_fn("f", &["x"], body);
        resolve_fn(&mut fd);

        // Param `x` holds slot 0, so the first pattern variable lands in slot 1.
        assert_eq!(fd.resolution.as_ref().unwrap().local_slots["v"], 1);

        match fd.body.as_ref() {
            FnBody::Expr(Expr::Match { arms, .. }) => {
                assert_eq!(*arms[0].body, Expr::Resolved(1))
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }
}