1use std::collections::HashMap;
13use std::sync::Arc as Rc;
14
15use crate::ast::*;
16
/// Program-wide type declarations, consulted while typing pattern-bound locals.
struct TypeInfo {
    /// Variant name -> every sum type declaring a variant with that name.
    /// More than one parent means the bare variant name is ambiguous.
    variant_parents: HashMap<String, Vec<String>>,
    /// `(sum type name, variant name)` -> the variant's field types, as
    /// unparsed type strings in declaration order.
    variants: HashMap<(String, String), Vec<String>>,
    /// Product (record) type name -> its declared fields; presumably
    /// `(field name, type string)` pairs — TODO confirm against `ast`.
    // Collected but not read anywhere yet, hence the lint allowance.
    #[allow(dead_code)]
    records: HashMap<String, Vec<(String, String)>>,
}
44
45fn build_type_info(items: &[TopLevel]) -> TypeInfo {
46 let mut variants: HashMap<(String, String), Vec<String>> = HashMap::new();
47 let mut variant_parents: HashMap<String, Vec<String>> = HashMap::new();
48 let mut records: HashMap<String, Vec<(String, String)>> = HashMap::new();
49 for item in items {
50 match item {
51 TopLevel::TypeDef(TypeDef::Sum {
52 name: parent,
53 variants: vs,
54 ..
55 }) => {
56 for v in vs {
57 variants.insert((parent.clone(), v.name.clone()), v.fields.clone());
58 variant_parents
59 .entry(v.name.clone())
60 .or_default()
61 .push(parent.clone());
62 }
63 }
64 TopLevel::TypeDef(TypeDef::Product { name, fields, .. }) => {
65 records.insert(name.clone(), fields.clone());
66 }
67 _ => {}
68 }
69 }
70 TypeInfo {
71 variants,
72 variant_parents,
73 records,
74 }
75}
76
77pub fn resolve_program(items: &mut [TopLevel]) {
82 let type_info = build_type_info(items);
83 for item in items.iter_mut() {
84 if let TopLevel::FnDef(fd) = item {
85 resolve_fn(fd, &type_info);
86 }
87 }
88}
89
90fn resolve_fn(fd: &mut FnDef, type_info: &TypeInfo) {
101 let mut state = ResolverState::new(type_info);
102
103 state.scopes.push(HashMap::new());
105 for (param_name, ty_str) in &fd.params {
106 let ty = crate::types::parse_type_str_strict(ty_str).unwrap_or(Type::Invalid);
107 let slot = state.declare(param_name, ty);
108 state.last_alloc.insert(param_name.clone(), slot);
109 }
110
111 let mut body = fd.body.as_ref().clone();
115 state.walk_stmts(body.stmts_mut());
116 fd.body = Rc::new(body);
117
118 let next_slot = state.next_slot;
119 let last_alloc = state.last_alloc;
120 let slot_types = state.slot_types;
121 state.scopes.pop();
122
123 fd.resolution = Some(FnResolution {
124 local_count: next_slot,
125 local_slots: Rc::new(last_alloc),
126 local_slot_types: Rc::new(slot_types),
127 aliased_slots: Rc::new(vec![false; next_slot as usize]),
128 });
129}
130
131struct ResolverState<'a> {
135 type_info: &'a TypeInfo,
136 next_slot: u16,
137 slot_types: Vec<Type>,
138 scopes: Vec<HashMap<String, u16>>,
143 last_alloc: HashMap<String, u16>,
150}
151
152impl<'a> ResolverState<'a> {
153 fn new(type_info: &'a TypeInfo) -> Self {
154 Self {
155 type_info,
156 next_slot: 0,
157 slot_types: Vec::new(),
158 scopes: Vec::new(),
159 last_alloc: HashMap::new(),
160 }
161 }
162
163 fn alloc(&mut self, ty: Type) -> u16 {
165 let idx = self.next_slot;
166 self.next_slot += 1;
167 self.slot_types.push(ty);
168 idx
169 }
170
171 fn declare(&mut self, name: &str, ty: Type) -> u16 {
185 if name == "_" {
186 return u16::MAX;
187 }
188 let slot = self.alloc(ty);
189 if let Some(scope) = self.scopes.last_mut() {
190 scope.insert(name.to_string(), slot);
191 }
192 self.last_alloc.insert(name.to_string(), slot);
193 slot
194 }
195
196 fn lookup(&self, name: &str) -> Option<u16> {
200 for scope in self.scopes.iter().rev() {
201 if let Some(&s) = scope.get(name) {
202 return Some(s);
203 }
204 }
205 None
206 }
207
208 fn walk_stmts(&mut self, stmts: &mut [Stmt]) {
209 for stmt in stmts {
210 match stmt {
211 Stmt::Binding(name, _annot, expr) => {
212 let ty = expr.ty().cloned().unwrap_or(Type::Invalid);
219 self.declare(name, ty);
220 self.walk_expr(expr);
221 }
222 Stmt::Expr(expr) => self.walk_expr(expr),
223 }
224 }
225 }
226
227 fn walk_expr(&mut self, expr: &mut Spanned<Expr>) {
228 match &mut expr.node {
229 Expr::Ident(name) => {
230 if let Some(slot) = self.lookup(name) {
231 expr.node = Expr::Resolved {
232 slot,
233 name: name.clone(),
234 last_use: AnnotBool(false),
235 };
236 }
237 }
238 Expr::Match { subject, arms } => {
239 self.walk_expr(subject);
240 let subject_ty = subject.ty().cloned();
241 for arm in arms.iter_mut() {
242 self.scopes.push(HashMap::new());
243 let slots = self.allocate_pattern(&arm.pattern, subject_ty.as_ref());
244 let _ = arm.binding_slots.set(slots);
245 self.walk_expr(&mut arm.body);
246 self.scopes.pop();
247 }
248 }
249 Expr::FnCall(func, args) => {
250 self.walk_expr(func);
251 for arg in args {
252 self.walk_expr(arg);
253 }
254 }
255 Expr::BinOp(_, l, r) => {
256 self.walk_expr(l);
257 self.walk_expr(r);
258 }
259 Expr::Attr(obj, _) => self.walk_expr(obj),
260 Expr::ErrorProp(inner) => self.walk_expr(inner),
261 Expr::Constructor(_, Some(inner)) => self.walk_expr(inner),
262 Expr::Constructor(_, None) => {}
263 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
264 for it in items {
265 self.walk_expr(it);
266 }
267 }
268 Expr::MapLiteral(entries) => {
269 for (k, v) in entries {
270 self.walk_expr(k);
271 self.walk_expr(v);
272 }
273 }
274 Expr::InterpolatedStr(parts) => {
275 for part in parts {
276 if let StrPart::Parsed(e) = part {
277 self.walk_expr(e);
278 }
279 }
280 }
281 Expr::RecordCreate { fields, .. } => {
282 for (_, e) in fields {
283 self.walk_expr(e);
284 }
285 }
286 Expr::RecordUpdate { base, updates, .. } => {
287 self.walk_expr(base);
288 for (_, e) in updates {
289 self.walk_expr(e);
290 }
291 }
292 Expr::TailCall(boxed) => {
293 for a in &mut boxed.args {
294 self.walk_expr(a);
295 }
296 }
297 Expr::Literal(_) | Expr::Resolved { .. } => {}
298 }
299 }
300
301 fn allocate_pattern(&mut self, pattern: &Pattern, subject_ty: Option<&Type>) -> Vec<u16> {
307 match pattern {
308 Pattern::Ident(name) => {
309 let ty = subject_ty.cloned().unwrap_or(Type::Invalid);
310 vec![self.declare(name, ty)]
311 }
312 Pattern::Cons(head, tail) => {
313 let elem_ty = match subject_ty {
314 Some(Type::List(inner)) => (**inner).clone(),
315 _ => Type::Invalid,
316 };
317 let list_ty = Type::List(Box::new(elem_ty.clone()));
318 vec![self.declare(head, elem_ty), self.declare(tail, list_ty)]
319 }
320 Pattern::Constructor(name, bindings) => {
321 let bare = name.rsplit('.').next().unwrap_or(name);
322 let parent_hint: Option<String> = match (subject_ty, name.split_once('.')) {
323 (Some(Type::Named(parent)), _) => Some(parent.clone()),
324 (_, Some((parent, _))) => Some(parent.to_string()),
325 _ => self
326 .type_info
327 .variant_parents
328 .get(bare)
329 .and_then(|parents| {
330 if parents.len() == 1 {
331 Some(parents[0].clone())
332 } else {
333 None
334 }
335 }),
336 };
337 let field_tys: Vec<Type> = match (bare, subject_ty) {
338 ("Ok", Some(Type::Result(t, _))) => vec![(**t).clone()],
339 ("Err", Some(Type::Result(_, e))) => vec![(**e).clone()],
340 ("Some", Some(Type::Option(inner))) => vec![(**inner).clone()],
341 ("None", _) => Vec::new(),
342 _ => parent_hint
343 .and_then(|p| self.type_info.variants.get(&(p, bare.to_string())))
344 .map(|fields| {
345 fields
346 .iter()
347 .map(|s| {
348 crate::types::parse_type_str_strict(s).unwrap_or(Type::Invalid)
349 })
350 .collect()
351 })
352 .unwrap_or_else(|| vec![Type::Invalid; bindings.len()]),
353 };
354 bindings
355 .iter()
356 .enumerate()
357 .map(|(i, name)| {
358 let ty = field_tys.get(i).cloned().unwrap_or(Type::Invalid);
359 self.declare(name, ty)
360 })
361 .collect()
362 }
363 Pattern::Tuple(items) => {
364 let elem_tys: Vec<Type> = match subject_ty {
365 Some(Type::Tuple(elems)) if elems.len() == items.len() => elems.to_vec(),
366 _ => vec![Type::Invalid; items.len()],
367 };
368 let mut slots = Vec::new();
369 for (item, elem_ty) in items.iter().zip(elem_tys.iter()) {
370 slots.extend(self.allocate_pattern(item, Some(elem_ty)));
371 }
372 slots
373 }
374 Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => Vec::new(),
375 }
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// A `TypeInfo` with no declared types.
    fn no_types() -> TypeInfo {
        TypeInfo {
            variants: HashMap::new(),
            variant_parents: HashMap::new(),
            records: HashMap::new(),
        }
    }

    /// An unresolved identifier node without a span.
    fn ident(name: &str) -> Spanned<Expr> {
        Spanned::bare(Expr::Ident(name.to_string()))
    }

    /// The node a local named `name` in `slot` should be rewritten into.
    fn resolved(slot: u16, name: &str) -> Expr {
        Expr::Resolved {
            slot,
            name: name.to_string(),
            last_use: AnnotBool(false),
        }
    }

    /// A pure function `name` with `Int` parameters `params` returning `Int`.
    fn int_fn(name: &str, params: &[&str], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: params
                .iter()
                .map(|p| (p.to_string(), "Int".to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    /// A match arm whose binding slots have not been assigned yet.
    fn arm(pattern: Pattern, body: Spanned<Expr>) -> MatchArm {
        MatchArm {
            pattern,
            body: Box::new(body),
            binding_slots: std::sync::OnceLock::new(),
        }
    }

    #[test]
    fn resolves_param_to_slot() {
        let mut fd = int_fn(
            "add",
            &["a", "b"],
            FnBody::from_expr(Spanned::bare(Expr::BinOp(
                BinOp::Add,
                Box::new(ident("a")),
                Box::new(ident("b")),
            ))),
        );
        resolve_fn(&mut fd, &no_types());

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        let Some(Spanned {
            node: Expr::BinOp(_, left, right),
            ..
        }) = fd.body.tail_expr()
        else {
            panic!("unexpected body: {:?}", fd.body.tail_expr());
        };
        assert_eq!(left.node, resolved(0, "a"));
        assert_eq!(right.node, resolved(1, "b"));
    }

    #[test]
    fn leaves_globals_as_ident() {
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::from_expr(Spanned::bare(Expr::FnCall(
                Box::new(ident("Console")),
                vec![ident("x")],
            ))),
        );
        resolve_fn(&mut fd, &no_types());

        let Some(Spanned {
            node: Expr::FnCall(func, args),
            ..
        }) = fd.body.tail_expr()
        else {
            panic!("unexpected body: {:?}", fd.body.tail_expr());
        };
        assert_eq!(func.node, Expr::Ident("Console".to_string()));
        assert_eq!(args[0].node, resolved(0, "x"));
    }

    #[test]
    fn resolves_val_in_block_body() {
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::Block(vec![
                Stmt::Binding(
                    "y".to_string(),
                    None,
                    Spanned::bare(Expr::BinOp(
                        BinOp::Add,
                        Box::new(ident("x")),
                        Box::new(Spanned::bare(Expr::Literal(Literal::Int(1)))),
                    )),
                ),
                Stmt::Expr(ident("y")),
            ]),
        );
        resolve_fn(&mut fd, &no_types());

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        // The initializer `x + 1` reads the parameter slot.
        let Stmt::Binding(
            _,
            _,
            Spanned {
                node: Expr::BinOp(_, left, _),
                ..
            },
        ) = &stmts[0]
        else {
            panic!("unexpected stmt: {:?}", &stmts[0]);
        };
        assert_eq!(left.node, resolved(0, "x"));

        // The trailing `y` reads the binding's slot.
        assert!(
            matches!(
                &stmts[1],
                Stmt::Expr(Spanned {
                    node: Expr::Resolved { slot: 1, .. },
                    ..
                })
            ),
            "unexpected stmt: {:?}",
            &stmts[1]
        );
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        // match x { Result.Ok(v) -> v, _ -> 0 }
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::from_expr(Spanned::new(
                Expr::Match {
                    subject: Box::new(ident("x")),
                    arms: vec![
                        arm(
                            Pattern::Constructor("Result.Ok".to_string(), vec!["v".to_string()]),
                            ident("v"),
                        ),
                        arm(
                            Pattern::Wildcard,
                            Spanned::bare(Expr::Literal(Literal::Int(0))),
                        ),
                    ],
                },
                1,
            )),
        );
        resolve_fn(&mut fd, &no_types());

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["v"], 1);

        let Some(Spanned {
            node: Expr::Match { arms, .. },
            ..
        }) = fd.body.tail_expr()
        else {
            panic!("unexpected body: {:?}", fd.body.tail_expr());
        };
        assert_eq!(arms[0].body.node, resolved(1, "v"));
    }

    #[test]
    fn resolves_match_pattern_bindings_inside_binding_initializer() {
        // result = match x { Option.Some(v) -> v, _ -> 0 }
        // result
        let mut fd = int_fn(
            "f",
            &["x"],
            FnBody::Block(vec![
                Stmt::Binding(
                    "result".to_string(),
                    None,
                    Spanned::bare(Expr::Match {
                        subject: Box::new(ident("x")),
                        arms: vec![
                            arm(
                                Pattern::Constructor(
                                    "Option.Some".to_string(),
                                    vec!["v".to_string()],
                                ),
                                ident("v"),
                            ),
                            arm(
                                Pattern::Wildcard,
                                Spanned::bare(Expr::Literal(Literal::Int(0))),
                            ),
                        ],
                    }),
                ),
                Stmt::Expr(ident("result")),
            ]),
        );
        resolve_fn(&mut fd, &no_types());

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["result"], 1);
        assert_eq!(res.local_slots["v"], 2);

        let stmts = fd.body.stmts();
        let Stmt::Binding(
            _,
            _,
            Spanned {
                node: Expr::Match { arms, .. },
                ..
            },
        ) = &stmts[0]
        else {
            panic!("unexpected stmt: {:?}", &stmts[0]);
        };
        assert_eq!(arms[0].body.node, resolved(2, "v"));

        assert!(
            matches!(
                &stmts[1],
                Stmt::Expr(Spanned {
                    node: Expr::Resolved { slot: 1, .. },
                    ..
                })
            ),
            "unexpected stmt: {:?}",
            &stmts[1]
        );
    }
}