1use std::collections::HashMap;
13use std::sync::Arc as Rc;
14
15use crate::ast::*;
16
/// Program-wide facts about user-declared types, gathered once before any
/// function is resolved.
struct TypeInfo {
    /// Field type strings of each sum-type variant, keyed by
    /// `(parent_type_name, variant_name)`.
    variants: HashMap<(String, String), Vec<String>>,
    /// For each variant name, every sum type that declares a variant with
    /// that name (more than one entry means the bare name is ambiguous).
    variant_parents: HashMap<String, Vec<String>>,
    /// `(field_name, field_type)` pairs of each product (record) type, keyed
    /// by type name.
    // Populated by `build_type_info` but not read anywhere in this module yet,
    // hence the lint allowance.
    #[allow(dead_code)]
    records: HashMap<String, Vec<(String, String)>>,
}
44
45fn build_type_info(items: &[TopLevel]) -> TypeInfo {
46 let mut variants: HashMap<(String, String), Vec<String>> = HashMap::new();
47 let mut variant_parents: HashMap<String, Vec<String>> = HashMap::new();
48 let mut records: HashMap<String, Vec<(String, String)>> = HashMap::new();
49 for item in items {
50 match item {
51 TopLevel::TypeDef(TypeDef::Sum {
52 name: parent,
53 variants: vs,
54 ..
55 }) => {
56 for v in vs {
57 variants.insert((parent.clone(), v.name.clone()), v.fields.clone());
58 variant_parents
59 .entry(v.name.clone())
60 .or_default()
61 .push(parent.clone());
62 }
63 }
64 TopLevel::TypeDef(TypeDef::Product { name, fields, .. }) => {
65 records.insert(name.clone(), fields.clone());
66 }
67 _ => {}
68 }
69 }
70 TypeInfo {
71 variants,
72 variant_parents,
73 records,
74 }
75}
76
77pub fn resolve_program(items: &mut [TopLevel]) {
82 let type_info = build_type_info(items);
83 for item in items.iter_mut() {
84 if let TopLevel::FnDef(fd) = item {
85 resolve_fn(fd, &type_info);
86 }
87 }
88}
89
90fn resolve_fn(fd: &mut FnDef, type_info: &TypeInfo) {
101 let mut state = ResolverState::new(type_info);
102
103 state.scopes.push(HashMap::new());
105 for (param_name, ty_str) in &fd.params {
106 let ty = crate::types::parse_type_str_strict(ty_str).unwrap_or(Type::Invalid);
107 let slot = state.declare(param_name, ty);
108 state.last_alloc.insert(param_name.clone(), slot);
109 }
110
111 let mut body = fd.body.as_ref().clone();
115 state.walk_stmts(body.stmts_mut());
116 fd.body = Rc::new(body);
117
118 let next_slot = state.next_slot;
119 let last_alloc = state.last_alloc;
120 let slot_types = state.slot_types;
121 state.scopes.pop();
122
123 fd.resolution = Some(FnResolution {
124 local_count: next_slot,
125 local_slots: Rc::new(last_alloc),
126 local_slot_types: Rc::new(slot_types),
127 });
128}
129
130struct ResolverState<'a> {
134 type_info: &'a TypeInfo,
135 next_slot: u16,
136 slot_types: Vec<Type>,
137 scopes: Vec<HashMap<String, u16>>,
142 last_alloc: HashMap<String, u16>,
149}
150
151impl<'a> ResolverState<'a> {
152 fn new(type_info: &'a TypeInfo) -> Self {
153 Self {
154 type_info,
155 next_slot: 0,
156 slot_types: Vec::new(),
157 scopes: Vec::new(),
158 last_alloc: HashMap::new(),
159 }
160 }
161
162 fn alloc(&mut self, ty: Type) -> u16 {
164 let idx = self.next_slot;
165 self.next_slot += 1;
166 self.slot_types.push(ty);
167 idx
168 }
169
170 fn declare(&mut self, name: &str, ty: Type) -> u16 {
184 if name == "_" {
185 return u16::MAX;
186 }
187 let slot = self.alloc(ty);
188 if let Some(scope) = self.scopes.last_mut() {
189 scope.insert(name.to_string(), slot);
190 }
191 self.last_alloc.insert(name.to_string(), slot);
192 slot
193 }
194
195 fn lookup(&self, name: &str) -> Option<u16> {
199 for scope in self.scopes.iter().rev() {
200 if let Some(&s) = scope.get(name) {
201 return Some(s);
202 }
203 }
204 None
205 }
206
207 fn walk_stmts(&mut self, stmts: &mut [Stmt]) {
208 for stmt in stmts {
209 match stmt {
210 Stmt::Binding(name, _annot, expr) => {
211 let ty = expr.ty().cloned().unwrap_or(Type::Invalid);
218 self.declare(name, ty);
219 self.walk_expr(expr);
220 }
221 Stmt::Expr(expr) => self.walk_expr(expr),
222 }
223 }
224 }
225
226 fn walk_expr(&mut self, expr: &mut Spanned<Expr>) {
227 match &mut expr.node {
228 Expr::Ident(name) => {
229 if let Some(slot) = self.lookup(name) {
230 expr.node = Expr::Resolved {
231 slot,
232 name: name.clone(),
233 last_use: AnnotBool(false),
234 };
235 }
236 }
237 Expr::Match { subject, arms } => {
238 self.walk_expr(subject);
239 let subject_ty = subject.ty().cloned();
240 for arm in arms.iter_mut() {
241 self.scopes.push(HashMap::new());
242 let slots = self.allocate_pattern(&arm.pattern, subject_ty.as_ref());
243 let _ = arm.binding_slots.set(slots);
244 self.walk_expr(&mut arm.body);
245 self.scopes.pop();
246 }
247 }
248 Expr::FnCall(func, args) => {
249 self.walk_expr(func);
250 for arg in args {
251 self.walk_expr(arg);
252 }
253 }
254 Expr::BinOp(_, l, r) => {
255 self.walk_expr(l);
256 self.walk_expr(r);
257 }
258 Expr::Attr(obj, _) => self.walk_expr(obj),
259 Expr::ErrorProp(inner) => self.walk_expr(inner),
260 Expr::Constructor(_, Some(inner)) => self.walk_expr(inner),
261 Expr::Constructor(_, None) => {}
262 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
263 for it in items {
264 self.walk_expr(it);
265 }
266 }
267 Expr::MapLiteral(entries) => {
268 for (k, v) in entries {
269 self.walk_expr(k);
270 self.walk_expr(v);
271 }
272 }
273 Expr::InterpolatedStr(parts) => {
274 for part in parts {
275 if let StrPart::Parsed(e) = part {
276 self.walk_expr(e);
277 }
278 }
279 }
280 Expr::RecordCreate { fields, .. } => {
281 for (_, e) in fields {
282 self.walk_expr(e);
283 }
284 }
285 Expr::RecordUpdate { base, updates, .. } => {
286 self.walk_expr(base);
287 for (_, e) in updates {
288 self.walk_expr(e);
289 }
290 }
291 Expr::TailCall(boxed) => {
292 for a in &mut boxed.args {
293 self.walk_expr(a);
294 }
295 }
296 Expr::Literal(_) | Expr::Resolved { .. } => {}
297 }
298 }
299
300 fn allocate_pattern(&mut self, pattern: &Pattern, subject_ty: Option<&Type>) -> Vec<u16> {
306 match pattern {
307 Pattern::Ident(name) => {
308 let ty = subject_ty.cloned().unwrap_or(Type::Invalid);
309 vec![self.declare(name, ty)]
310 }
311 Pattern::Cons(head, tail) => {
312 let elem_ty = match subject_ty {
313 Some(Type::List(inner)) => (**inner).clone(),
314 _ => Type::Invalid,
315 };
316 let list_ty = Type::List(Box::new(elem_ty.clone()));
317 vec![self.declare(head, elem_ty), self.declare(tail, list_ty)]
318 }
319 Pattern::Constructor(name, bindings) => {
320 let bare = name.rsplit('.').next().unwrap_or(name);
321 let parent_hint: Option<String> = match (subject_ty, name.split_once('.')) {
322 (Some(Type::Named(parent)), _) => Some(parent.clone()),
323 (_, Some((parent, _))) => Some(parent.to_string()),
324 _ => self
325 .type_info
326 .variant_parents
327 .get(bare)
328 .and_then(|parents| {
329 if parents.len() == 1 {
330 Some(parents[0].clone())
331 } else {
332 None
333 }
334 }),
335 };
336 let field_tys: Vec<Type> = match (bare, subject_ty) {
337 ("Ok", Some(Type::Result(t, _))) => vec![(**t).clone()],
338 ("Err", Some(Type::Result(_, e))) => vec![(**e).clone()],
339 ("Some", Some(Type::Option(inner))) => vec![(**inner).clone()],
340 ("None", _) => Vec::new(),
341 _ => parent_hint
342 .and_then(|p| self.type_info.variants.get(&(p, bare.to_string())))
343 .map(|fields| {
344 fields
345 .iter()
346 .map(|s| {
347 crate::types::parse_type_str_strict(s).unwrap_or(Type::Invalid)
348 })
349 .collect()
350 })
351 .unwrap_or_else(|| vec![Type::Invalid; bindings.len()]),
352 };
353 bindings
354 .iter()
355 .enumerate()
356 .map(|(i, name)| {
357 let ty = field_tys.get(i).cloned().unwrap_or(Type::Invalid);
358 self.declare(name, ty)
359 })
360 .collect()
361 }
362 Pattern::Tuple(items) => {
363 let elem_tys: Vec<Type> = match subject_ty {
364 Some(Type::Tuple(elems)) if elems.len() == items.len() => elems.to_vec(),
365 _ => vec![Type::Invalid; items.len()],
366 };
367 let mut slots = Vec::new();
368 for (item, elem_ty) in items.iter().zip(elem_tys.iter()) {
369 slots.extend(self.allocate_pattern(item, Some(elem_ty)));
370 }
371 slots
372 }
373 Pattern::Wildcard | Pattern::Literal(_) | Pattern::EmptyList => Vec::new(),
374 }
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Type facts for a program that declares no user types.
    fn no_types() -> TypeInfo {
        TypeInfo {
            variants: HashMap::new(),
            variant_parents: HashMap::new(),
            records: HashMap::new(),
        }
    }

    /// An unspanned, not-yet-resolved reference to `name`.
    fn ident(name: &str) -> Spanned<Expr> {
        Spanned::bare(Expr::Ident(name.to_string()))
    }

    /// The node an identifier bound to `slot` should be rewritten into.
    fn resolved(slot: u16, name: &str) -> Expr {
        Expr::Resolved {
            slot,
            name: name.to_string(),
            last_use: AnnotBool(false),
        }
    }

    /// A match arm whose binding slots have not been filled in yet.
    fn arm(pattern: Pattern, body: Spanned<Expr>) -> MatchArm {
        MatchArm {
            pattern,
            body: Box::new(body),
            binding_slots: std::sync::OnceLock::new(),
        }
    }

    /// An unresolved function `name(params) -> Int` with the given body.
    fn int_fn(name: &str, params: &[(&str, &str)], body: FnBody) -> FnDef {
        FnDef {
            name: name.to_string(),
            line: 1,
            params: params
                .iter()
                .map(|&(p, t)| (p.to_string(), t.to_string()))
                .collect(),
            return_type: "Int".to_string(),
            effects: vec![],
            desc: None,
            body: Rc::new(body),
            resolution: None,
        }
    }

    #[test]
    fn resolves_param_to_slot() {
        let body = FnBody::from_expr(Spanned::bare(Expr::BinOp(
            BinOp::Add,
            Box::new(ident("a")),
            Box::new(ident("b")),
        )));
        let mut fd = int_fn("add", &[("a", "Int"), ("b", "Int")], body);
        resolve_fn(&mut fd, &no_types());

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["a"], 0);
        assert_eq!(res.local_slots["b"], 1);
        assert_eq!(res.local_count, 2);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::BinOp(_, left, right),
                ..
            }) => {
                assert_eq!(left.node, resolved(0, "a"));
                assert_eq!(right.node, resolved(1, "b"));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn leaves_globals_as_ident() {
        let body = FnBody::from_expr(Spanned::bare(Expr::FnCall(
            Box::new(ident("Console")),
            vec![ident("x")],
        )));
        let mut fd = int_fn("f", &[("x", "Int")], body);
        resolve_fn(&mut fd, &no_types());

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::FnCall(func, args),
                ..
            }) => {
                assert_eq!(func.node, Expr::Ident("Console".to_string()));
                assert_eq!(args[0].node, resolved(0, "x"));
            }
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_val_in_block_body() {
        let y_init = Spanned::bare(Expr::BinOp(
            BinOp::Add,
            Box::new(ident("x")),
            Box::new(Spanned::bare(Expr::Literal(Literal::Int(1)))),
        ));
        let body = FnBody::Block(vec![
            Stmt::Binding("y".to_string(), None, y_init),
            Stmt::Expr(ident("y")),
        ]);
        let mut fd = int_fn("f", &[("x", "Int")], body);
        resolve_fn(&mut fd, &no_types());

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["y"], 1);
        assert_eq!(res.local_count, 2);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::BinOp(_, left, _),
                    ..
                },
            ) => assert_eq!(left.node, resolved(0, "x")),
            other => panic!("unexpected stmt: {:?}", other),
        }
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved { slot: 1, .. },
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings() {
        let matcher = Expr::Match {
            subject: Box::new(ident("x")),
            arms: vec![
                arm(
                    Pattern::Constructor("Result.Ok".to_string(), vec!["v".to_string()]),
                    ident("v"),
                ),
                arm(
                    Pattern::Wildcard,
                    Spanned::bare(Expr::Literal(Literal::Int(0))),
                ),
            ],
        };
        let body = FnBody::from_expr(Spanned::new(matcher, 1));
        let mut fd = int_fn("f", &[("x", "Int")], body);
        resolve_fn(&mut fd, &no_types());

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["v"], 1);

        match fd.body.tail_expr() {
            Some(Spanned {
                node: Expr::Match { arms, .. },
                ..
            }) => assert_eq!(arms[0].body.node, resolved(1, "v")),
            other => panic!("unexpected body: {:?}", other),
        }
    }

    #[test]
    fn resolves_match_pattern_bindings_inside_binding_initializer() {
        let initializer = Spanned::bare(Expr::Match {
            subject: Box::new(ident("x")),
            arms: vec![
                arm(
                    Pattern::Constructor("Option.Some".to_string(), vec!["v".to_string()]),
                    ident("v"),
                ),
                arm(
                    Pattern::Wildcard,
                    Spanned::bare(Expr::Literal(Literal::Int(0))),
                ),
            ],
        });
        let body = FnBody::Block(vec![
            Stmt::Binding("result".to_string(), None, initializer),
            Stmt::Expr(ident("result")),
        ]);
        let mut fd = int_fn("f", &[("x", "Int")], body);
        resolve_fn(&mut fd, &no_types());

        let res = fd.resolution.as_ref().unwrap();
        assert_eq!(res.local_slots["x"], 0);
        assert_eq!(res.local_slots["result"], 1);
        assert_eq!(res.local_slots["v"], 2);

        let stmts = fd.body.stmts();
        match &stmts[0] {
            Stmt::Binding(
                _,
                _,
                Spanned {
                    node: Expr::Match { arms, .. },
                    ..
                },
            ) => assert_eq!(arms[0].body.node, resolved(2, "v")),
            other => panic!("unexpected stmt: {:?}", other),
        }
        match &stmts[1] {
            Stmt::Expr(Spanned {
                node: Expr::Resolved { slot: 1, .. },
                ..
            }) => {}
            other => panic!("unexpected stmt: {:?}", other),
        }
    }
}
693}