1use crate::{
6 diagnostics::Diagnostics,
7 fun::{num_to_name, Adt, Book, Ctx, FanKind, MatchRule, Name, Num, Op, Pattern, Tag, Term, Type},
8 maybe_grow,
9};
10use std::collections::{BTreeMap, BTreeSet, HashMap};
11
12impl Ctx<'_> {
13 pub fn type_check(&mut self) -> Result<(), Diagnostics> {
14 let types = infer_book(self.book, &mut self.info)?;
15
16 for def in self.book.defs.values_mut() {
17 def.typ = types[&def.name].instantiate(&mut VarGen::default());
18 }
19
20 Ok(())
21 }
22}
23
24type ProgramTypes = HashMap<Name, Scheme>;
25
26#[derive(Clone, Debug)]
28struct Scheme(Vec<Name>, Type);
29
30#[derive(Clone, Default, Debug)]
32struct Subst(BTreeMap<Name, Type>);
33
34#[derive(Clone, Default, Debug)]
36struct TypeEnv(BTreeMap<Name, Scheme>);
37
38#[derive(Default)]
40struct VarGen(usize);
41
42struct RecGroups(Vec<Vec<Name>>);
44
45impl Type {
48 fn free_type_vars(&self) -> BTreeSet<Name> {
49 maybe_grow(|| match self {
50 Type::Var(x) => BTreeSet::from([x.clone()]),
51 Type::Ctr(_, ts) | Type::Tup(ts) => ts.iter().flat_map(|t| t.free_type_vars()).collect(),
52 Type::Arr(t1, t2) => t1.free_type_vars().union(&t2.free_type_vars()).cloned().collect(),
53 Type::Number(t) | Type::Integer(t) => t.free_type_vars(),
54 Type::U24 | Type::F24 | Type::I24 | Type::None | Type::Any | Type::Hole => BTreeSet::new(),
55 })
56 }
57
58 fn subst(&self, subst: &Subst) -> Type {
59 maybe_grow(|| match self {
60 Type::Var(nam) => match subst.0.get(nam) {
61 Some(new) => new.clone(),
62 None => self.clone(),
63 },
64 Type::Ctr(name, ts) => Type::Ctr(name.clone(), ts.iter().map(|t| t.subst(subst)).collect()),
65 Type::Arr(t1, t2) => Type::Arr(Box::new(t1.subst(subst)), Box::new(t2.subst(subst))),
66 Type::Tup(els) => Type::Tup(els.iter().map(|t| t.subst(subst)).collect()),
67 Type::Number(t) => Type::Number(Box::new(t.subst(subst))),
68 Type::Integer(t) => Type::Integer(Box::new(t.subst(subst))),
69 t @ (Type::U24 | Type::F24 | Type::I24 | Type::None | Type::Any | Type::Hole) => t.clone(),
70 })
71 }
72
73 fn generalize(&self, env: &TypeEnv) -> Scheme {
77 let vars_env = env.free_type_vars();
78 let vars_t = self.free_type_vars();
79 let vars = vars_t.difference(&vars_env).cloned().collect();
80 Scheme(vars, self.clone())
81 }
82}
83
84impl Scheme {
85 fn free_type_vars(&self) -> BTreeSet<Name> {
86 let vars = self.1.free_type_vars();
87 let bound_vars = self.0.iter().cloned().collect();
88 vars.difference(&bound_vars).cloned().collect()
89 }
90
91 fn subst(&self, subst: &Subst) -> Scheme {
92 let mut subst = subst.clone();
93 for x in self.0.iter() {
94 subst.0.remove(x);
95 }
96 let t = self.1.subst(&subst);
97 Scheme(self.0.clone(), t)
98 }
99
100 fn instantiate(&self, var_gen: &mut VarGen) -> Type {
103 let new_vars = self.0.iter().map(|_| var_gen.fresh());
104 let subst = Subst(self.0.iter().cloned().zip(new_vars).collect());
105 self.1.subst(&subst)
106 }
107}
108
109impl Subst {
110 fn compose(mut self, other: Subst) -> Subst {
114 let other = other.0.into_iter().map(|(x, t)| (x, t.subst(&self))).collect::<Vec<_>>();
115 self.0.extend(other);
116 self
117 }
118}
119
120impl TypeEnv {
121 fn free_type_vars(&self) -> BTreeSet<Name> {
122 let mut vars = BTreeSet::new();
123 for scheme in self.0.values() {
124 let scheme_vars = scheme.free_type_vars();
125 vars = vars.union(&scheme_vars).cloned().collect();
126 }
127 vars
128 }
129
130 fn subst(&self, subst: &Subst) -> TypeEnv {
131 let env = self.0.iter().map(|(x, scheme)| (x.clone(), scheme.subst(subst))).collect();
132 TypeEnv(env)
133 }
134
135 fn insert(&mut self, name: Name, scheme: Scheme) {
136 self.0.insert(name, scheme);
137 }
138
139 fn add_binds<'a>(
140 &mut self,
141 bnd: impl IntoIterator<Item = (&'a Option<Name>, Scheme)>,
142 ) -> Vec<(Name, Option<Scheme>)> {
143 let mut old_bnd = vec![];
144 for (name, scheme) in bnd {
145 if let Some(name) = name {
146 let old = self.0.insert(name.clone(), scheme);
147 old_bnd.push((name.clone(), old));
148 }
149 }
150 old_bnd
151 }
152
153 fn pop_binds(&mut self, old_bnd: Vec<(Name, Option<Scheme>)>) {
154 for (name, scheme) in old_bnd {
155 if let Some(scheme) = scheme {
156 self.0.insert(name, scheme);
157 }
158 }
159 }
160}
161
162impl VarGen {
163 fn fresh(&mut self) -> Type {
164 let x = self.fresh_name();
165 Type::Var(x)
166 }
167
168 fn fresh_name(&mut self) -> Name {
169 let x = num_to_name(self.0 as u64);
170 self.0 += 1;
171 Name::new(x)
172 }
173}
174
175impl RecGroups {
176 fn from_book(book: &Book) -> RecGroups {
177 type DependencyGraph<'a> = BTreeMap<&'a Name, BTreeSet<&'a Name>>;
178
179 fn collect_dependencies<'a>(
180 term: &'a Term,
181 book: &'a Book,
182 scope: &mut Vec<Name>,
183 deps: &mut BTreeSet<&'a Name>,
184 ) {
185 if let Term::Ref { nam } = term {
186 if book.ctrs.contains_key(nam) || book.hvm_defs.contains_key(nam) || !book.defs[nam].check {
187 } else {
189 deps.insert(nam);
190 }
191 }
192 for (child, binds) in term.children_with_binds() {
193 scope.extend(binds.clone().flatten().cloned());
194 collect_dependencies(child, book, scope, deps);
195 scope.truncate(scope.len() - binds.flatten().count());
196 }
197 }
198
199 fn strong_connect<'a>(
201 v: &'a Name,
202 deps: &DependencyGraph<'a>,
203 index: &mut usize,
204 index_map: &mut BTreeMap<&'a Name, usize>,
205 low_link: &mut BTreeMap<&'a Name, usize>,
206 stack: &mut Vec<&'a Name>,
207 components: &mut Vec<BTreeSet<Name>>,
208 ) {
209 maybe_grow(|| {
210 index_map.insert(v, *index);
211 low_link.insert(v, *index);
212 *index += 1;
213 stack.push(v);
214
215 if let Some(neighbors) = deps.get(v) {
216 for w in neighbors {
217 if !index_map.contains_key(w) {
218 strong_connect(w, deps, index, index_map, low_link, stack, components);
220 low_link.insert(v, low_link[v].min(low_link[w]));
221 } else if stack.contains(w) {
222 low_link.insert(v, low_link[v].min(index_map[w]));
224 } else {
225 }
228 }
229 }
230
231 if low_link[v] == index_map[v] {
233 let mut component = BTreeSet::new();
234 while let Some(w) = stack.pop() {
235 component.insert(w.clone());
236 if w == v {
237 break;
238 }
239 }
240 components.push(component);
241 }
242 })
243 }
244
245 let mut deps = DependencyGraph::default();
247 for (name, def) in &book.defs {
248 if book.ctrs.contains_key(name) || !def.check {
249 continue;
251 }
252 let mut fn_deps = Default::default();
253 collect_dependencies(&def.rule().body, book, &mut vec![], &mut fn_deps);
254 deps.insert(name, fn_deps);
255 }
256
257 let mut index = 0;
258 let mut stack = Vec::new();
259 let mut index_map = BTreeMap::new();
260 let mut low_link = BTreeMap::new();
261 let mut components = Vec::new();
262 for name in deps.keys() {
263 if !index_map.contains_key(name) {
264 strong_connect(name, &deps, &mut index, &mut index_map, &mut low_link, &mut stack, &mut components);
265 }
266 }
267 let components = components.into_iter().map(|x| x.into_iter().collect()).collect();
268 RecGroups(components)
269 }
270}
271
272fn infer_book(book: &Book, diags: &mut Diagnostics) -> Result<ProgramTypes, Diagnostics> {
274 let groups = RecGroups::from_book(book);
275 let mut env = TypeEnv::default();
276 let mut types = ProgramTypes::default();
279
280 for adt in book.adts.values() {
282 for ctr in adt.ctrs.values() {
283 types.insert(ctr.name.clone(), ctr.typ.generalize(&TypeEnv::default()));
284 }
285 }
286 for def in book.defs.values() {
288 if !def.check {
289 types.insert(def.name.clone(), def.typ.generalize(&TypeEnv::default()));
290 }
291 }
292 for def in book.hvm_defs.values() {
294 types.insert(def.name.clone(), def.typ.generalize(&TypeEnv::default()));
295 }
296
297 for group in &groups.0 {
299 infer_group(&mut env, book, group, &mut types, diags)?;
300 }
301 Ok(types)
302}
303
304fn infer_group(
305 env: &mut TypeEnv,
306 book: &Book,
307 group: &[Name],
308 types: &mut ProgramTypes,
309 diags: &mut Diagnostics,
310) -> Result<(), Diagnostics> {
311 let var_gen = &mut VarGen::default();
312 let tvs = group.iter().map(|_| var_gen.fresh()).collect::<Vec<_>>();
314 for (name, tv) in group.iter().zip(tvs.iter()) {
315 env.insert(name.clone(), Scheme(vec![], tv.clone()));
316 }
317
318 let mut ss = vec![];
320 let mut inf_ts = vec![];
321 let mut exp_ts = vec![];
322 for name in group {
323 let def = &book.defs[name];
324 let (s, t) = infer(env, book, types, &def.rule().body, var_gen).map_err(|e| {
325 diags.add_function_error(e, name.clone(), def.source.clone());
326 std::mem::take(diags)
327 })?;
328 let t = t.subst(&s);
329 ss.push(s);
330 inf_ts.push(t);
331 exp_ts.push(&def.typ);
332 }
333
334 for name in group.iter() {
337 env.0.remove(name);
338 }
339
340 let mut s = ss.into_iter().fold(Subst::default(), |acc, s| acc.compose(s));
342 let mut ts = vec![];
343 for ((bod_t, tv), nam) in inf_ts.into_iter().zip(tvs.iter()).zip(group.iter()) {
344 let (t, s2) = unify_term(&tv.subst(&s), &bod_t, &book.defs[nam].rule().body)?;
345 ts.push(t);
346 s = s.compose(s2);
347 }
348 let ts = ts.into_iter().map(|t| t.subst(&s)).collect::<Vec<_>>();
349
350 for ((name, exp_t), inf_t) in group.iter().zip(exp_ts.iter()).zip(ts.iter()) {
352 let t = specialize(inf_t, exp_t).map_err(|e| {
353 diags.add_function_error(e, name.clone(), book.defs[name].source.clone());
354 std::mem::take(diags)
355 })?;
356 types.insert(name.clone(), t.generalize(&TypeEnv::default()));
357 }
358
359 diags.fatal(())
360}
361
362fn infer(
369 env: &mut TypeEnv,
370 book: &Book,
371 types: &ProgramTypes,
372 term: &Term,
373 var_gen: &mut VarGen,
374) -> Result<(Subst, Type), String> {
375 let res = maybe_grow(|| match term {
376 Term::Var { nam } | Term::Ref { nam } => {
377 if let Some(scheme) = env.0.get(nam) {
378 Ok::<_, String>((Subst::default(), scheme.instantiate(var_gen)))
379 } else if let Some(scheme) = types.get(nam) {
380 Ok((Subst::default(), scheme.instantiate(var_gen)))
381 } else {
382 unreachable!("unbound name '{}'", nam)
383 }
384 }
385 Term::Lam { tag: Tag::Static, pat, bod } => match pat.as_ref() {
386 Pattern::Var(nam) => {
387 let tv = var_gen.fresh();
388 let old_bnd = env.add_binds([(nam, Scheme(vec![], tv.clone()))]);
389 let (s, bod_t) = infer(env, book, types, bod, var_gen)?;
390 env.pop_binds(old_bnd);
391 let var_t = tv.subst(&s);
392 Ok((s, Type::Arr(Box::new(var_t), Box::new(bod_t))))
393 }
394 _ => unreachable!("{}", term),
395 },
396 Term::App { tag: Tag::Static, fun, arg } => {
397 let (s1, fun_t) = infer(env, book, types, fun, var_gen)?;
398 let (s2, arg_t) = infer(&mut env.subst(&s1), book, types, arg, var_gen)?;
399 let app_t = var_gen.fresh();
400 let (_, s3) = unify_term(&fun_t.subst(&s2), &Type::Arr(Box::new(arg_t), Box::new(app_t.clone())), fun)?;
401 let t = app_t.subst(&s3);
402 Ok((s3.compose(s2).compose(s1), t))
403 }
404 Term::Let { pat, val, nxt } => match pat.as_ref() {
405 Pattern::Var(nam) => {
406 let (s1, val_t) = infer(env, book, types, val, var_gen)?;
407 let old_bnd = env.add_binds([(nam, val_t.generalize(&env.subst(&s1)))]);
408 let (s2, nxt_t) = infer(&mut env.subst(&s1), book, types, nxt, var_gen)?;
409 env.pop_binds(old_bnd);
410 Ok((s2.compose(s1), nxt_t))
411 }
412 Pattern::Fan(FanKind::Tup, Tag::Static, _) => {
413 debug_assert!(!(pat.has_unscoped() || pat.has_nested()));
416 let (s1, val_t) = infer(env, book, types, val, var_gen)?;
417
418 let tvs = pat.binds().map(|_| var_gen.fresh()).collect::<Vec<_>>();
419 let old_bnd = env.add_binds(pat.binds().zip(tvs.iter().map(|tv| Scheme(vec![], tv.clone()))));
420 let (s2, nxt_t) = infer(&mut env.subst(&s1), book, types, nxt, var_gen)?;
421 env.pop_binds(old_bnd);
422 let tvs = tvs.into_iter().map(|tv| tv.subst(&s2)).collect::<Vec<_>>();
423 let (_, s3) = unify_term(&val_t, &Type::Tup(tvs), val)?;
424 Ok((s3.compose(s2).compose(s1), nxt_t))
425 }
426 Pattern::Fan(FanKind::Dup, Tag::Auto, _) => {
427 debug_assert!(!(pat.has_unscoped() || pat.has_nested()));
430 let (s1, mut val_t) = infer(env, book, types, val, var_gen)?;
431 let tvs = pat.binds().map(|_| var_gen.fresh()).collect::<Vec<_>>();
432 let old_bnd = env.add_binds(pat.binds().zip(tvs.iter().map(|tv| Scheme(vec![], tv.clone()))));
433 let (mut s2, nxt_t) = infer(&mut env.subst(&s1), book, types, nxt, var_gen)?;
434 env.pop_binds(old_bnd);
435 for tv in tvs {
436 let (val_t_, s) = unify_term(&val_t, &tv.subst(&s2), val)?;
437 val_t = val_t_;
438 s2 = s2.compose(s);
439 }
440 Ok((s2.compose(s1), nxt_t))
441 }
442 _ => unreachable!(),
443 },
444
445 Term::Mat { bnd: _, arg, with_bnd: _, with_arg: _, arms } => {
446 let (s1, t1) = infer(env, book, types, arg, var_gen)?;
448
449 let adt_name = book.ctrs.get(arms[0].0.as_ref().unwrap()).unwrap();
451 let adt = &book.adts[adt_name];
452 let (adt_s, adt_t) = instantiate_adt(adt, var_gen)?;
453
454 let (s2, nxt_t) = infer_match_cases(env.subst(&s1), book, types, adt, arms, &adt_s, var_gen)?;
458
459 let (_, s3) = unify_term(&t1, &adt_t.subst(&s2), arg)?;
461 Ok((s3.compose(s2).compose(s1), nxt_t))
462 }
463
464 Term::Num { val } => {
465 let t = match val {
466 Num::U24(_) => Type::U24,
467 Num::I24(_) => Type::I24,
468 Num::F24(_) => Type::F24,
469 };
470 Ok((Subst::default(), t))
471 }
472 Term::Oper { opr, fst, snd } => {
473 let (s1, t1) = infer(env, book, types, fst, var_gen)?;
474 let (s2, t2) = infer(&mut env.subst(&s1), book, types, snd, var_gen)?;
475 let (t2, s3) = unify_term(&t2.subst(&s1), &t1.subst(&s2), term)?;
476 let s_args = s3.compose(s2).compose(s1);
477 let t_args = t2.subst(&s_args);
478 let tv = var_gen.fresh();
480 let (t_opr, s_opr) = match opr {
481 Op::ADD | Op::SUB | Op::MUL | Op::DIV => {
483 unify_term(&t_args, &Type::Number(Box::new(tv.clone())), term)?
484 }
485 Op::EQ | Op::NEQ | Op::LT | Op::GT | Op::GE | Op::LE => {
486 let (_, s) = unify_term(&t_args, &Type::Number(Box::new(tv.clone())), term)?;
487 (Type::U24, s)
488 }
489 Op::REM | Op::AND | Op::OR | Op::XOR | Op::SHL | Op::SHR => {
491 unify_term(&t_args, &Type::Integer(Box::new(tv.clone())), term)?
492 }
493 Op::POW => unify_term(&t_args, &Type::F24, term)?,
495 };
496 let t = t_opr.subst(&s_opr);
497 Ok((s_opr.compose(s_args), t))
498 }
499 Term::Swt { bnd: _, arg, with_bnd: _, with_arg: _, pred, arms } => {
500 let (s1, t1) = infer(env, book, types, arg, var_gen)?;
501 let (_, s2) = unify_term(&t1, &Type::U24, arg)?;
502 let s_arg = s2.compose(s1);
503 let mut env = env.subst(&s_arg);
504
505 let mut ss_nums = vec![];
506 let mut ts_nums = vec![];
507 for arm in arms.iter().rev().skip(1) {
508 let (s, t) = infer(&mut env, book, types, arm, var_gen)?;
509 env = env.subst(&s);
510 ss_nums.push(s);
511 ts_nums.push(t);
512 }
513 let old_bnd = env.add_binds([(pred, Scheme(vec![], Type::U24))]);
514 let (s_succ, t_succ) = infer(&mut env, book, types, &arms[1], var_gen)?;
515 env.pop_binds(old_bnd);
516
517 let s_arms = ss_nums.into_iter().fold(s_succ, |acc, s| acc.compose(s));
518 let mut t_swt = t_succ;
519 let mut s_swt = Subst::default();
520 for t_num in ts_nums {
521 let (t, s) = unify_term(&t_swt, &t_num, term)?;
522 t_swt = t;
523 s_swt = s.compose(s_swt);
524 }
525
526 let s = s_swt.compose(s_arms).compose(s_arg);
527 let t = t_swt.subst(&s);
528 Ok((s, t))
529 }
530
531 Term::Fan { fan: FanKind::Tup, tag: Tag::Static, els } => {
532 let res = els.iter().map(|el| infer(env, book, types, el, var_gen)).collect::<Result<Vec<_>, _>>()?;
533 let (ss, ts): (Vec<Subst>, Vec<Type>) = res.into_iter().unzip();
534 let t = Type::Tup(ts);
535 let s = ss.into_iter().fold(Subst::default(), |acc, s| acc.compose(s));
536 Ok((s, t))
537 }
538 Term::Era => Ok((Subst::default(), Type::None)),
539 Term::Fan { .. } | Term::Lam { tag: _, .. } | Term::App { tag: _, .. } | Term::Link { .. } => {
540 unreachable!("'{term}' while type checking. Should never occur in checked functions")
541 }
542 Term::Use { .. }
543 | Term::With { .. }
544 | Term::Ask { .. }
545 | Term::Nat { .. }
546 | Term::Str { .. }
547 | Term::List { .. }
548 | Term::Fold { .. }
549 | Term::Bend { .. }
550 | Term::Open { .. }
551 | Term::Def { .. }
552 | Term::Err => unreachable!("'{term}' while type checking. Should have been removed in earlier pass"),
553 })?;
554 Ok(res)
555}
556
557fn instantiate_adt(adt: &Adt, var_gen: &mut VarGen) -> Result<(Subst, Type), String> {
561 let tvs = adt.vars.iter().map(|_| var_gen.fresh());
562 let s = Subst(adt.vars.iter().zip(tvs).map(|(x, t)| (x.clone(), t)).collect());
563 let t = Type::Ctr(adt.name.clone(), adt.vars.iter().cloned().map(Type::Var).collect());
564 let t = t.subst(&s);
565 Ok((s, t))
566}
567
568fn infer_match_cases(
569 mut env: TypeEnv,
570 book: &Book,
571 types: &ProgramTypes,
572 adt: &Adt,
573 arms: &[MatchRule],
574 adt_s: &Subst,
575 var_gen: &mut VarGen,
576) -> Result<(Subst, Type), String> {
577 maybe_grow(|| {
578 if let Some(((ctr_nam, vars, bod), rest)) = arms.split_first() {
579 let ctr = &adt.ctrs[ctr_nam.as_ref().unwrap()];
580 let tvs = vars.iter().map(|_| var_gen.fresh()).collect::<Vec<_>>();
582
583 let old_bnd = env.add_binds(vars.iter().zip(tvs.iter().map(|tv| Scheme(vec![], tv.clone()))));
585 let (s1, t1) = infer(&mut env, book, types, bod, var_gen)?;
586 env.pop_binds(old_bnd);
587 let inf_ts = tvs.into_iter().map(|tv| tv.subst(&s1)).collect::<Vec<_>>();
588 let exp_ts = ctr.fields.iter().map(|f| f.typ.subst(adt_s)).collect::<Vec<_>>();
589 let s2 = unify_fields(inf_ts.iter().zip(exp_ts.iter()), bod)?;
590
591 let s = s2.compose(s1);
593 let (s_rest, t_rest) = infer_match_cases(env.subst(&s), book, types, adt, rest, adt_s, var_gen)?;
594 let (t_final, s_final) = unify_term(&t1.subst(&s), &t_rest, bod)?;
595
596 Ok((s_final.compose(s_rest).compose(s), t_final))
597 } else {
598 Ok((Subst::default(), var_gen.fresh()))
599 }
600 })
601}
602
603fn unify_fields<'a>(ts: impl Iterator<Item = (&'a Type, &'a Type)>, ctx: &Term) -> Result<Subst, String> {
604 let ss = ts.map(|(inf, exp)| unify_term(inf, exp, ctx)).collect::<Result<Vec<_>, _>>()?;
605 let mut s = Subst::default();
606 for (_, s2) in ss.into_iter().rev() {
607 s = s.compose(s2);
608 }
609 Ok(s)
610}
611
612fn unify_term(t1: &Type, t2: &Type, ctx: &Term) -> Result<(Type, Subst), String> {
613 match unify(t1, t2) {
614 Ok((t, s)) => Ok((t, s)),
615 Err(msg) => Err(format!("In {ctx}:\n Can't unify '{t1}' and '{t2}'.{msg}")),
616 }
617}
618
619fn unify(t1: &Type, t2: &Type) -> Result<(Type, Subst), String> {
620 maybe_grow(|| match (t1, t2) {
621 (t, Type::Hole) | (Type::Hole, t) => Ok((t.clone(), Subst::default())),
622 (t, Type::Var(x)) | (Type::Var(x), t) => {
623 if let Type::Var(y) = t {
625 if y == x {
626 return Ok((t.clone(), Subst::default()));
628 }
629 }
630 if t.free_type_vars().contains(x) {
632 return Err(format!(" Variable '{x}' occurs in '{t}'"));
633 }
634 Ok((t.clone(), Subst(BTreeMap::from([(x.clone(), t.clone())]))))
635 }
636 (Type::Arr(l1, r1), Type::Arr(l2, r2)) => {
637 let (t1, s1) = unify(l1, l2)?;
638 let (t2, s2) = unify(&r1.subst(&s1), &r2.subst(&s1))?;
639 Ok((Type::Arr(Box::new(t1), Box::new(t2)), s2.compose(s1)))
640 }
641 (Type::Ctr(name1, ts1), Type::Ctr(name2, ts2)) if name1 == name2 && ts1.len() == ts2.len() => {
642 let mut s = Subst::default();
643 let mut ts = vec![];
644 for (t1, t2) in ts1.iter().zip(ts2.iter()) {
645 let (t, s2) = unify(t1, t2)?;
646 ts.push(t);
647 s = s.compose(s2);
648 }
649 Ok((Type::Ctr(name1.clone(), ts), s))
650 }
651 (Type::Tup(els1), Type::Tup(els2)) if els1.len() == els2.len() => {
652 let mut s = Subst::default();
653 let mut ts = vec![];
654 for (t1, t2) in els1.iter().zip(els2.iter()) {
655 let (t, s2) = unify(t1, t2)?;
656 ts.push(t);
657 s = s.compose(s2);
658 }
659 Ok((Type::Tup(ts), s))
660 }
661 t @ ((Type::U24, Type::U24)
662 | (Type::F24, Type::F24)
663 | (Type::I24, Type::I24)
664 | (Type::None, Type::None)) => Ok((t.0.clone(), Subst::default())),
665 (Type::Number(t1), Type::Number(t2)) => {
666 let (t, s) = unify(t1, t2)?;
667 Ok((Type::Number(Box::new(t)), s))
668 }
669 (Type::Number(tn), Type::Integer(ti)) | (Type::Integer(ti), Type::Number(tn)) => {
670 let (t, s) = unify(ti, tn)?;
671 Ok((Type::Integer(Box::new(t)), s))
672 }
673 (Type::Integer(t1), Type::Integer(t2)) => {
674 let (t, s) = unify(t1, t2)?;
675 Ok((Type::Integer(Box::new(t)), s))
676 }
677 (Type::Number(t1) | Type::Integer(t1), t2 @ (Type::U24 | Type::I24 | Type::F24))
678 | (t2 @ (Type::U24 | Type::I24 | Type::F24), Type::Number(t1) | Type::Integer(t1)) => {
679 let (t, s) = unify(t1, t2)?;
680 Ok((t, s))
681 }
682
683 (Type::Any, t) | (t, Type::Any) => {
684 let mut s = Subst::default();
685 for child in t.children() {
687 let (_, s2) = unify(&Type::Any, child)?;
688 s = s2.compose(s);
689 }
690 Ok((Type::Any, s))
691 }
692
693 _ => Err(String::new()),
694 })
695}
696
697fn specialize(inf: &Type, ann: &Type) -> Result<Type, String> {
705 fn merge_specialization(inf: &Type, exp: &Type, s: &mut Subst) -> Result<Type, String> {
706 maybe_grow(|| match (inf, exp) {
707 (t, Type::Hole) => Ok(t.clone()),
709 (Type::Hole, _) => unreachable!("Hole should never appear in the inferred type"),
710
711 (_inf, Type::Any) => Ok(Type::Any),
712 (Type::Any, exp) => Ok(exp.clone()),
713
714 (Type::Var(x), new) => {
715 if let Some(old) = s.0.get(x) {
716 if old == new {
717 Ok(new.clone())
718 } else {
719 Err(format!(" Inferred type variable '{x}' must be both '{old}' and '{new}'"))
720 }
721 } else {
722 s.0.insert(x.clone(), new.clone());
723 Ok(new.clone())
724 }
725 }
726
727 (Type::Arr(l1, r1), Type::Arr(l2, r2)) => {
728 let l = merge_specialization(l1, l2, s)?;
729 let r = merge_specialization(r1, r2, s)?;
730 Ok(Type::Arr(Box::new(l), Box::new(r)))
731 }
732 (Type::Ctr(name1, ts1), Type::Ctr(name2, ts2)) if name1 == name2 && ts1.len() == ts2.len() => {
733 let mut ts = vec![];
734 for (t1, t2) in ts1.iter().zip(ts2.iter()) {
735 let t = merge_specialization(t1, t2, s)?;
736 ts.push(t);
737 }
738 Ok(Type::Ctr(name1.clone(), ts))
739 }
740 (Type::Tup(ts1), Type::Tup(ts2)) if ts1.len() == ts2.len() => {
741 let mut ts = vec![];
742 for (t1, t2) in ts1.iter().zip(ts2.iter()) {
743 let t = merge_specialization(t1, t2, s)?;
744 ts.push(t);
745 }
746 Ok(Type::Tup(ts))
747 }
748 (Type::Number(t1), Type::Number(t2)) => Ok(Type::Number(Box::new(merge_specialization(t1, t2, s)?))),
749 (Type::Integer(t1), Type::Integer(t2)) => Ok(Type::Integer(Box::new(merge_specialization(t1, t2, s)?))),
750 (Type::U24, Type::U24) | (Type::F24, Type::F24) | (Type::I24, Type::I24) | (Type::None, Type::None) => {
751 Ok(inf.clone())
752 }
753 _ => Err(String::new()),
754 })
755 }
756
757 let var_gen = &mut VarGen::default();
760 let inf2 = inf.generalize(&TypeEnv::default()).instantiate(var_gen);
761 let ann2 = ann.generalize(&TypeEnv::default()).instantiate(var_gen);
762
763 let (t, s) = unify(&inf2, &ann2)
764 .map_err(|e| format!("Type Error: Expected function type '{ann}' but found '{inf}'.{e}"))?;
765 let t = t.subst(&s);
766
767 let mut merge_s = Subst::default();
770 let t2 = merge_specialization(&t, ann, &mut merge_s).map_err(|e| {
771 format!("Type Error: Annotated type '{ann}' is not a subtype of inferred type '{inf2}'.{e}")
772 })?;
773
774 Ok(t2.subst(&merge_s))
775}
776
777impl std::fmt::Display for Subst {
778 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
779 writeln!(f, "Subst {{")?;
780 for (x, y) in &self.0 {
781 writeln!(f, " {x} => {y},")?;
782 }
783 write!(f, "}}")
784 }
785}