Skip to main content

bend/fun/check/
type_check.rs

1//! Optional Hindley-Milner-like type system.
2//!
3//! Based on https://github.com/developedby/algorithm-w-rs
4//! and https://github.com/mgrabmueller/AlgorithmW.
5use crate::{
6  diagnostics::Diagnostics,
7  fun::{num_to_name, Adt, Book, Ctx, FanKind, MatchRule, Name, Num, Op, Pattern, Tag, Term, Type},
8  maybe_grow,
9};
10use std::collections::{BTreeMap, BTreeSet, HashMap};
11
12impl Ctx<'_> {
13  pub fn type_check(&mut self) -> Result<(), Diagnostics> {
14    let types = infer_book(self.book, &mut self.info)?;
15
16    for def in self.book.defs.values_mut() {
17      def.typ = types[&def.name].instantiate(&mut VarGen::default());
18    }
19
20    Ok(())
21  }
22}
23
24type ProgramTypes = HashMap<Name, Scheme>;
25
26/// A type scheme, aka a polymorphic type.
27#[derive(Clone, Debug)]
28struct Scheme(Vec<Name>, Type);
29
30/// A finite mapping from type variables to types.
31#[derive(Clone, Default, Debug)]
32struct Subst(BTreeMap<Name, Type>);
33
34/// A mapping from term variables to type schemes.
35#[derive(Clone, Default, Debug)]
36struct TypeEnv(BTreeMap<Name, Scheme>);
37
38/// Variable generator for type variables.
39#[derive(Default)]
40struct VarGen(usize);
41
42/// Topologically ordered set of mutually recursive groups of functions.
43struct RecGroups(Vec<Vec<Name>>);
44
45/* Implementations */
46
47impl Type {
48  fn free_type_vars(&self) -> BTreeSet<Name> {
49    maybe_grow(|| match self {
50      Type::Var(x) => BTreeSet::from([x.clone()]),
51      Type::Ctr(_, ts) | Type::Tup(ts) => ts.iter().flat_map(|t| t.free_type_vars()).collect(),
52      Type::Arr(t1, t2) => t1.free_type_vars().union(&t2.free_type_vars()).cloned().collect(),
53      Type::Number(t) | Type::Integer(t) => t.free_type_vars(),
54      Type::U24 | Type::F24 | Type::I24 | Type::None | Type::Any | Type::Hole => BTreeSet::new(),
55    })
56  }
57
58  fn subst(&self, subst: &Subst) -> Type {
59    maybe_grow(|| match self {
60      Type::Var(nam) => match subst.0.get(nam) {
61        Some(new) => new.clone(),
62        None => self.clone(),
63      },
64      Type::Ctr(name, ts) => Type::Ctr(name.clone(), ts.iter().map(|t| t.subst(subst)).collect()),
65      Type::Arr(t1, t2) => Type::Arr(Box::new(t1.subst(subst)), Box::new(t2.subst(subst))),
66      Type::Tup(els) => Type::Tup(els.iter().map(|t| t.subst(subst)).collect()),
67      Type::Number(t) => Type::Number(Box::new(t.subst(subst))),
68      Type::Integer(t) => Type::Integer(Box::new(t.subst(subst))),
69      t @ (Type::U24 | Type::F24 | Type::I24 | Type::None | Type::Any | Type::Hole) => t.clone(),
70    })
71  }
72
73  /// Converts a monomorphic type into a type scheme by abstracting
74  /// over the type variables free in `t`, but not free in the type
75  /// environment.
76  fn generalize(&self, env: &TypeEnv) -> Scheme {
77    let vars_env = env.free_type_vars();
78    let vars_t = self.free_type_vars();
79    let vars = vars_t.difference(&vars_env).cloned().collect();
80    Scheme(vars, self.clone())
81  }
82}
83
84impl Scheme {
85  fn free_type_vars(&self) -> BTreeSet<Name> {
86    let vars = self.1.free_type_vars();
87    let bound_vars = self.0.iter().cloned().collect();
88    vars.difference(&bound_vars).cloned().collect()
89  }
90
91  fn subst(&self, subst: &Subst) -> Scheme {
92    let mut subst = subst.clone();
93    for x in self.0.iter() {
94      subst.0.remove(x);
95    }
96    let t = self.1.subst(&subst);
97    Scheme(self.0.clone(), t)
98  }
99
100  /// Converts a type scheme into a monomorphic type by assigning
101  /// fresh type variables to each variable bound by the scheme.
102  fn instantiate(&self, var_gen: &mut VarGen) -> Type {
103    let new_vars = self.0.iter().map(|_| var_gen.fresh());
104    let subst = Subst(self.0.iter().cloned().zip(new_vars).collect());
105    self.1.subst(&subst)
106  }
107}
108
109impl Subst {
110  /// Compose two substitutions.
111  ///
112  /// Applies the first substitution to the second, and then inserts the result into the first.
113  fn compose(mut self, other: Subst) -> Subst {
114    let other = other.0.into_iter().map(|(x, t)| (x, t.subst(&self))).collect::<Vec<_>>();
115    self.0.extend(other);
116    self
117  }
118}
119
120impl TypeEnv {
121  fn free_type_vars(&self) -> BTreeSet<Name> {
122    let mut vars = BTreeSet::new();
123    for scheme in self.0.values() {
124      let scheme_vars = scheme.free_type_vars();
125      vars = vars.union(&scheme_vars).cloned().collect();
126    }
127    vars
128  }
129
130  fn subst(&self, subst: &Subst) -> TypeEnv {
131    let env = self.0.iter().map(|(x, scheme)| (x.clone(), scheme.subst(subst))).collect();
132    TypeEnv(env)
133  }
134
135  fn insert(&mut self, name: Name, scheme: Scheme) {
136    self.0.insert(name, scheme);
137  }
138
139  fn add_binds<'a>(
140    &mut self,
141    bnd: impl IntoIterator<Item = (&'a Option<Name>, Scheme)>,
142  ) -> Vec<(Name, Option<Scheme>)> {
143    let mut old_bnd = vec![];
144    for (name, scheme) in bnd {
145      if let Some(name) = name {
146        let old = self.0.insert(name.clone(), scheme);
147        old_bnd.push((name.clone(), old));
148      }
149    }
150    old_bnd
151  }
152
153  fn pop_binds(&mut self, old_bnd: Vec<(Name, Option<Scheme>)>) {
154    for (name, scheme) in old_bnd {
155      if let Some(scheme) = scheme {
156        self.0.insert(name, scheme);
157      }
158    }
159  }
160}
161
162impl VarGen {
163  fn fresh(&mut self) -> Type {
164    let x = self.fresh_name();
165    Type::Var(x)
166  }
167
168  fn fresh_name(&mut self) -> Name {
169    let x = num_to_name(self.0 as u64);
170    self.0 += 1;
171    Name::new(x)
172  }
173}
174
175impl RecGroups {
176  fn from_book(book: &Book) -> RecGroups {
177    type DependencyGraph<'a> = BTreeMap<&'a Name, BTreeSet<&'a Name>>;
178
179    fn collect_dependencies<'a>(
180      term: &'a Term,
181      book: &'a Book,
182      scope: &mut Vec<Name>,
183      deps: &mut BTreeSet<&'a Name>,
184    ) {
185      if let Term::Ref { nam } = term {
186        if book.ctrs.contains_key(nam) || book.hvm_defs.contains_key(nam) || !book.defs[nam].check {
187          // Don't infer types for constructors or unchecked functions
188        } else {
189          deps.insert(nam);
190        }
191      }
192      for (child, binds) in term.children_with_binds() {
193        scope.extend(binds.clone().flatten().cloned());
194        collect_dependencies(child, book, scope, deps);
195        scope.truncate(scope.len() - binds.flatten().count());
196      }
197    }
198
199    /// Tarjan's algorithm for finding strongly connected components.
200    fn strong_connect<'a>(
201      v: &'a Name,
202      deps: &DependencyGraph<'a>,
203      index: &mut usize,
204      index_map: &mut BTreeMap<&'a Name, usize>,
205      low_link: &mut BTreeMap<&'a Name, usize>,
206      stack: &mut Vec<&'a Name>,
207      components: &mut Vec<BTreeSet<Name>>,
208    ) {
209      maybe_grow(|| {
210        index_map.insert(v, *index);
211        low_link.insert(v, *index);
212        *index += 1;
213        stack.push(v);
214
215        if let Some(neighbors) = deps.get(v) {
216          for w in neighbors {
217            if !index_map.contains_key(w) {
218              // Successor w has not yet been visited, recurse on it.
219              strong_connect(w, deps, index, index_map, low_link, stack, components);
220              low_link.insert(v, low_link[v].min(low_link[w]));
221            } else if stack.contains(w) {
222              // Successor w is in stack S and hence in the current SCC.
223              low_link.insert(v, low_link[v].min(index_map[w]));
224            } else {
225              // If w is not on stack, then (v, w) is an edge pointing
226              // to an SCC already found and must be ignored.
227            }
228          }
229        }
230
231        // If v is a root node, pop the stack and generate an SCC.
232        if low_link[v] == index_map[v] {
233          let mut component = BTreeSet::new();
234          while let Some(w) = stack.pop() {
235            component.insert(w.clone());
236            if w == v {
237              break;
238            }
239          }
240          components.push(component);
241        }
242      })
243    }
244
245    // Build the dependency graph
246    let mut deps = DependencyGraph::default();
247    for (name, def) in &book.defs {
248      if book.ctrs.contains_key(name) || !def.check {
249        // Don't infer types for constructors or unchecked functions
250        continue;
251      }
252      let mut fn_deps = Default::default();
253      collect_dependencies(&def.rule().body, book, &mut vec![], &mut fn_deps);
254      deps.insert(name, fn_deps);
255    }
256
257    let mut index = 0;
258    let mut stack = Vec::new();
259    let mut index_map = BTreeMap::new();
260    let mut low_link = BTreeMap::new();
261    let mut components = Vec::new();
262    for name in deps.keys() {
263      if !index_map.contains_key(name) {
264        strong_connect(name, &deps, &mut index, &mut index_map, &mut low_link, &mut stack, &mut components);
265      }
266    }
267    let components = components.into_iter().map(|x| x.into_iter().collect()).collect();
268    RecGroups(components)
269  }
270}
271
272/* Inference, unification and type checking */
273fn infer_book(book: &Book, diags: &mut Diagnostics) -> Result<ProgramTypes, Diagnostics> {
274  let groups = RecGroups::from_book(book);
275  let mut env = TypeEnv::default();
276  // Note: We store the inferred and generalized types in a separate
277  // environment, to avoid unnecessary cloning (since no immutable data).
278  let mut types = ProgramTypes::default();
279
280  // Add the constructors to the environment.
281  for adt in book.adts.values() {
282    for ctr in adt.ctrs.values() {
283      types.insert(ctr.name.clone(), ctr.typ.generalize(&TypeEnv::default()));
284    }
285  }
286  // Add the types of unchecked functions to the environment.
287  for def in book.defs.values() {
288    if !def.check {
289      types.insert(def.name.clone(), def.typ.generalize(&TypeEnv::default()));
290    }
291  }
292  // Add the types of hvm functions to the environment.
293  for def in book.hvm_defs.values() {
294    types.insert(def.name.clone(), def.typ.generalize(&TypeEnv::default()));
295  }
296
297  // Infer the types of regular functions.
298  for group in &groups.0 {
299    infer_group(&mut env, book, group, &mut types, diags)?;
300  }
301  Ok(types)
302}
303
304fn infer_group(
305  env: &mut TypeEnv,
306  book: &Book,
307  group: &[Name],
308  types: &mut ProgramTypes,
309  diags: &mut Diagnostics,
310) -> Result<(), Diagnostics> {
311  let var_gen = &mut VarGen::default();
312  // Generate fresh type variables for each function in the group.
313  let tvs = group.iter().map(|_| var_gen.fresh()).collect::<Vec<_>>();
314  for (name, tv) in group.iter().zip(tvs.iter()) {
315    env.insert(name.clone(), Scheme(vec![], tv.clone()));
316  }
317
318  // Infer the types of the functions in the group.
319  let mut ss = vec![];
320  let mut inf_ts = vec![];
321  let mut exp_ts = vec![];
322  for name in group {
323    let def = &book.defs[name];
324    let (s, t) = infer(env, book, types, &def.rule().body, var_gen).map_err(|e| {
325      diags.add_function_error(e, name.clone(), def.source.clone());
326      std::mem::take(diags)
327    })?;
328    let t = t.subst(&s);
329    ss.push(s);
330    inf_ts.push(t);
331    exp_ts.push(&def.typ);
332  }
333
334  // Remove the type variables of the group from the environment.
335  // This avoids cloning of already generalized types.
336  for name in group.iter() {
337    env.0.remove(name);
338  }
339
340  // Unify the inferred body with the corresponding type variable.
341  let mut s = ss.into_iter().fold(Subst::default(), |acc, s| acc.compose(s));
342  let mut ts = vec![];
343  for ((bod_t, tv), nam) in inf_ts.into_iter().zip(tvs.iter()).zip(group.iter()) {
344    let (t, s2) = unify_term(&tv.subst(&s), &bod_t, &book.defs[nam].rule().body)?;
345    ts.push(t);
346    s = s.compose(s2);
347  }
348  let ts = ts.into_iter().map(|t| t.subst(&s)).collect::<Vec<_>>();
349
350  // Specialize against the expected type, then generalize and store.
351  for ((name, exp_t), inf_t) in group.iter().zip(exp_ts.iter()).zip(ts.iter()) {
352    let t = specialize(inf_t, exp_t).map_err(|e| {
353      diags.add_function_error(e, name.clone(), book.defs[name].source.clone());
354      std::mem::take(diags)
355    })?;
356    types.insert(name.clone(), t.generalize(&TypeEnv::default()));
357  }
358
359  diags.fatal(())
360}
361
362/// Infer the type of a term in the given environment.
363///
364/// The type environment must contain bindings for all the free variables of the term.
365///
366/// The returned substitution records the type constraints imposed on type variables by the term.
367/// The returned type is the type of the term.
368fn infer(
369  env: &mut TypeEnv,
370  book: &Book,
371  types: &ProgramTypes,
372  term: &Term,
373  var_gen: &mut VarGen,
374) -> Result<(Subst, Type), String> {
375  let res = maybe_grow(|| match term {
376    Term::Var { nam } | Term::Ref { nam } => {
377      if let Some(scheme) = env.0.get(nam) {
378        Ok::<_, String>((Subst::default(), scheme.instantiate(var_gen)))
379      } else if let Some(scheme) = types.get(nam) {
380        Ok((Subst::default(), scheme.instantiate(var_gen)))
381      } else {
382        unreachable!("unbound name '{}'", nam)
383      }
384    }
385    Term::Lam { tag: Tag::Static, pat, bod } => match pat.as_ref() {
386      Pattern::Var(nam) => {
387        let tv = var_gen.fresh();
388        let old_bnd = env.add_binds([(nam, Scheme(vec![], tv.clone()))]);
389        let (s, bod_t) = infer(env, book, types, bod, var_gen)?;
390        env.pop_binds(old_bnd);
391        let var_t = tv.subst(&s);
392        Ok((s, Type::Arr(Box::new(var_t), Box::new(bod_t))))
393      }
394      _ => unreachable!("{}", term),
395    },
396    Term::App { tag: Tag::Static, fun, arg } => {
397      let (s1, fun_t) = infer(env, book, types, fun, var_gen)?;
398      let (s2, arg_t) = infer(&mut env.subst(&s1), book, types, arg, var_gen)?;
399      let app_t = var_gen.fresh();
400      let (_, s3) = unify_term(&fun_t.subst(&s2), &Type::Arr(Box::new(arg_t), Box::new(app_t.clone())), fun)?;
401      let t = app_t.subst(&s3);
402      Ok((s3.compose(s2).compose(s1), t))
403    }
404    Term::Let { pat, val, nxt } => match pat.as_ref() {
405      Pattern::Var(nam) => {
406        let (s1, val_t) = infer(env, book, types, val, var_gen)?;
407        let old_bnd = env.add_binds([(nam, val_t.generalize(&env.subst(&s1)))]);
408        let (s2, nxt_t) = infer(&mut env.subst(&s1), book, types, nxt, var_gen)?;
409        env.pop_binds(old_bnd);
410        Ok((s2.compose(s1), nxt_t))
411      }
412      Pattern::Fan(FanKind::Tup, Tag::Static, _) => {
413        // Tuple elimination behaves like pattern matching.
414        // Variables from tuple patterns don't get generalized.
415        debug_assert!(!(pat.has_unscoped() || pat.has_nested()));
416        let (s1, val_t) = infer(env, book, types, val, var_gen)?;
417
418        let tvs = pat.binds().map(|_| var_gen.fresh()).collect::<Vec<_>>();
419        let old_bnd = env.add_binds(pat.binds().zip(tvs.iter().map(|tv| Scheme(vec![], tv.clone()))));
420        let (s2, nxt_t) = infer(&mut env.subst(&s1), book, types, nxt, var_gen)?;
421        env.pop_binds(old_bnd);
422        let tvs = tvs.into_iter().map(|tv| tv.subst(&s2)).collect::<Vec<_>>();
423        let (_, s3) = unify_term(&val_t, &Type::Tup(tvs), val)?;
424        Ok((s3.compose(s2).compose(s1), nxt_t))
425      }
426      Pattern::Fan(FanKind::Dup, Tag::Auto, _) => {
427        // We pretend that sups don't exist and dups don't collide.
428        // All variables must have the same type as the body of the dup.
429        debug_assert!(!(pat.has_unscoped() || pat.has_nested()));
430        let (s1, mut val_t) = infer(env, book, types, val, var_gen)?;
431        let tvs = pat.binds().map(|_| var_gen.fresh()).collect::<Vec<_>>();
432        let old_bnd = env.add_binds(pat.binds().zip(tvs.iter().map(|tv| Scheme(vec![], tv.clone()))));
433        let (mut s2, nxt_t) = infer(&mut env.subst(&s1), book, types, nxt, var_gen)?;
434        env.pop_binds(old_bnd);
435        for tv in tvs {
436          let (val_t_, s) = unify_term(&val_t, &tv.subst(&s2), val)?;
437          val_t = val_t_;
438          s2 = s2.compose(s);
439        }
440        Ok((s2.compose(s1), nxt_t))
441      }
442      _ => unreachable!(),
443    },
444
445    Term::Mat { bnd: _, arg, with_bnd: _, with_arg: _, arms } => {
446      // Infer type of the scrutinee
447      let (s1, t1) = infer(env, book, types, arg, var_gen)?;
448
449      // Instantiate the expected type of the scrutinee
450      let adt_name = book.ctrs.get(arms[0].0.as_ref().unwrap()).unwrap();
451      let adt = &book.adts[adt_name];
452      let (adt_s, adt_t) = instantiate_adt(adt, var_gen)?;
453
454      // For each case, infer the types and unify them all.
455      // Unify the inferred type of the destructured fields with the
456      // expected from what we inferred from the scrutinee.
457      let (s2, nxt_t) = infer_match_cases(env.subst(&s1), book, types, adt, arms, &adt_s, var_gen)?;
458
459      // Unify the inferred type with the expected type
460      let (_, s3) = unify_term(&t1, &adt_t.subst(&s2), arg)?;
461      Ok((s3.compose(s2).compose(s1), nxt_t))
462    }
463
464    Term::Num { val } => {
465      let t = match val {
466        Num::U24(_) => Type::U24,
467        Num::I24(_) => Type::I24,
468        Num::F24(_) => Type::F24,
469      };
470      Ok((Subst::default(), t))
471    }
472    Term::Oper { opr, fst, snd } => {
473      let (s1, t1) = infer(env, book, types, fst, var_gen)?;
474      let (s2, t2) = infer(&mut env.subst(&s1), book, types, snd, var_gen)?;
475      let (t2, s3) = unify_term(&t2.subst(&s1), &t1.subst(&s2), term)?;
476      let s_args = s3.compose(s2).compose(s1);
477      let t_args = t2.subst(&s_args);
478      // Check numeric type matches the operation
479      let tv = var_gen.fresh();
480      let (t_opr, s_opr) = match opr {
481        // Any numeric type
482        Op::ADD | Op::SUB | Op::MUL | Op::DIV => {
483          unify_term(&t_args, &Type::Number(Box::new(tv.clone())), term)?
484        }
485        Op::EQ | Op::NEQ | Op::LT | Op::GT | Op::GE | Op::LE => {
486          let (_, s) = unify_term(&t_args, &Type::Number(Box::new(tv.clone())), term)?;
487          (Type::U24, s)
488        }
489        // Integers
490        Op::REM | Op::AND | Op::OR | Op::XOR | Op::SHL | Op::SHR => {
491          unify_term(&t_args, &Type::Integer(Box::new(tv.clone())), term)?
492        }
493        // Floating
494        Op::POW => unify_term(&t_args, &Type::F24, term)?,
495      };
496      let t = t_opr.subst(&s_opr);
497      Ok((s_opr.compose(s_args), t))
498    }
499    Term::Swt { bnd: _, arg, with_bnd: _, with_arg: _, pred, arms } => {
500      let (s1, t1) = infer(env, book, types, arg, var_gen)?;
501      let (_, s2) = unify_term(&t1, &Type::U24, arg)?;
502      let s_arg = s2.compose(s1);
503      let mut env = env.subst(&s_arg);
504
505      let mut ss_nums = vec![];
506      let mut ts_nums = vec![];
507      for arm in arms.iter().rev().skip(1) {
508        let (s, t) = infer(&mut env, book, types, arm, var_gen)?;
509        env = env.subst(&s);
510        ss_nums.push(s);
511        ts_nums.push(t);
512      }
513      let old_bnd = env.add_binds([(pred, Scheme(vec![], Type::U24))]);
514      let (s_succ, t_succ) = infer(&mut env, book, types, &arms[1], var_gen)?;
515      env.pop_binds(old_bnd);
516
517      let s_arms = ss_nums.into_iter().fold(s_succ, |acc, s| acc.compose(s));
518      let mut t_swt = t_succ;
519      let mut s_swt = Subst::default();
520      for t_num in ts_nums {
521        let (t, s) = unify_term(&t_swt, &t_num, term)?;
522        t_swt = t;
523        s_swt = s.compose(s_swt);
524      }
525
526      let s = s_swt.compose(s_arms).compose(s_arg);
527      let t = t_swt.subst(&s);
528      Ok((s, t))
529    }
530
531    Term::Fan { fan: FanKind::Tup, tag: Tag::Static, els } => {
532      let res = els.iter().map(|el| infer(env, book, types, el, var_gen)).collect::<Result<Vec<_>, _>>()?;
533      let (ss, ts): (Vec<Subst>, Vec<Type>) = res.into_iter().unzip();
534      let t = Type::Tup(ts);
535      let s = ss.into_iter().fold(Subst::default(), |acc, s| acc.compose(s));
536      Ok((s, t))
537    }
538    Term::Era => Ok((Subst::default(), Type::None)),
539    Term::Fan { .. } | Term::Lam { tag: _, .. } | Term::App { tag: _, .. } | Term::Link { .. } => {
540      unreachable!("'{term}' while type checking. Should never occur in checked functions")
541    }
542    Term::Use { .. }
543    | Term::With { .. }
544    | Term::Ask { .. }
545    | Term::Nat { .. }
546    | Term::Str { .. }
547    | Term::List { .. }
548    | Term::Fold { .. }
549    | Term::Bend { .. }
550    | Term::Open { .. }
551    | Term::Def { .. }
552    | Term::Err => unreachable!("'{term}' while type checking. Should have been removed in earlier pass"),
553  })?;
554  Ok(res)
555}
556
557/// Instantiates the type constructor of an ADT, also returning the
558/// ADT var to instantiated var substitution, to be used when
559/// instantiating the types of the fields of the eliminated constructors.
560fn instantiate_adt(adt: &Adt, var_gen: &mut VarGen) -> Result<(Subst, Type), String> {
561  let tvs = adt.vars.iter().map(|_| var_gen.fresh());
562  let s = Subst(adt.vars.iter().zip(tvs).map(|(x, t)| (x.clone(), t)).collect());
563  let t = Type::Ctr(adt.name.clone(), adt.vars.iter().cloned().map(Type::Var).collect());
564  let t = t.subst(&s);
565  Ok((s, t))
566}
567
568fn infer_match_cases(
569  mut env: TypeEnv,
570  book: &Book,
571  types: &ProgramTypes,
572  adt: &Adt,
573  arms: &[MatchRule],
574  adt_s: &Subst,
575  var_gen: &mut VarGen,
576) -> Result<(Subst, Type), String> {
577  maybe_grow(|| {
578    if let Some(((ctr_nam, vars, bod), rest)) = arms.split_first() {
579      let ctr = &adt.ctrs[ctr_nam.as_ref().unwrap()];
580      // One fresh var per field, we later unify with the expected type.
581      let tvs = vars.iter().map(|_| var_gen.fresh()).collect::<Vec<_>>();
582
583      // Infer the body and unify the inferred field types with the expected.
584      let old_bnd = env.add_binds(vars.iter().zip(tvs.iter().map(|tv| Scheme(vec![], tv.clone()))));
585      let (s1, t1) = infer(&mut env, book, types, bod, var_gen)?;
586      env.pop_binds(old_bnd);
587      let inf_ts = tvs.into_iter().map(|tv| tv.subst(&s1)).collect::<Vec<_>>();
588      let exp_ts = ctr.fields.iter().map(|f| f.typ.subst(adt_s)).collect::<Vec<_>>();
589      let s2 = unify_fields(inf_ts.iter().zip(exp_ts.iter()), bod)?;
590
591      // Recurse and unify with the other arms.
592      let s = s2.compose(s1);
593      let (s_rest, t_rest) = infer_match_cases(env.subst(&s), book, types, adt, rest, adt_s, var_gen)?;
594      let (t_final, s_final) = unify_term(&t1.subst(&s), &t_rest, bod)?;
595
596      Ok((s_final.compose(s_rest).compose(s), t_final))
597    } else {
598      Ok((Subst::default(), var_gen.fresh()))
599    }
600  })
601}
602
603fn unify_fields<'a>(ts: impl Iterator<Item = (&'a Type, &'a Type)>, ctx: &Term) -> Result<Subst, String> {
604  let ss = ts.map(|(inf, exp)| unify_term(inf, exp, ctx)).collect::<Result<Vec<_>, _>>()?;
605  let mut s = Subst::default();
606  for (_, s2) in ss.into_iter().rev() {
607    s = s.compose(s2);
608  }
609  Ok(s)
610}
611
612fn unify_term(t1: &Type, t2: &Type, ctx: &Term) -> Result<(Type, Subst), String> {
613  match unify(t1, t2) {
614    Ok((t, s)) => Ok((t, s)),
615    Err(msg) => Err(format!("In {ctx}:\n  Can't unify '{t1}' and '{t2}'.{msg}")),
616  }
617}
618
619fn unify(t1: &Type, t2: &Type) -> Result<(Type, Subst), String> {
620  maybe_grow(|| match (t1, t2) {
621    (t, Type::Hole) | (Type::Hole, t) => Ok((t.clone(), Subst::default())),
622    (t, Type::Var(x)) | (Type::Var(x), t) => {
623      // Try to bind variable `x` to `t`
624      if let Type::Var(y) = t {
625        if y == x {
626          // Don't bind a variable to itself
627          return Ok((t.clone(), Subst::default()));
628        }
629      }
630      // Occurs check
631      if t.free_type_vars().contains(x) {
632        return Err(format!(" Variable '{x}' occurs in '{t}'"));
633      }
634      Ok((t.clone(), Subst(BTreeMap::from([(x.clone(), t.clone())]))))
635    }
636    (Type::Arr(l1, r1), Type::Arr(l2, r2)) => {
637      let (t1, s1) = unify(l1, l2)?;
638      let (t2, s2) = unify(&r1.subst(&s1), &r2.subst(&s1))?;
639      Ok((Type::Arr(Box::new(t1), Box::new(t2)), s2.compose(s1)))
640    }
641    (Type::Ctr(name1, ts1), Type::Ctr(name2, ts2)) if name1 == name2 && ts1.len() == ts2.len() => {
642      let mut s = Subst::default();
643      let mut ts = vec![];
644      for (t1, t2) in ts1.iter().zip(ts2.iter()) {
645        let (t, s2) = unify(t1, t2)?;
646        ts.push(t);
647        s = s.compose(s2);
648      }
649      Ok((Type::Ctr(name1.clone(), ts), s))
650    }
651    (Type::Tup(els1), Type::Tup(els2)) if els1.len() == els2.len() => {
652      let mut s = Subst::default();
653      let mut ts = vec![];
654      for (t1, t2) in els1.iter().zip(els2.iter()) {
655        let (t, s2) = unify(t1, t2)?;
656        ts.push(t);
657        s = s.compose(s2);
658      }
659      Ok((Type::Tup(ts), s))
660    }
661    t @ ((Type::U24, Type::U24)
662    | (Type::F24, Type::F24)
663    | (Type::I24, Type::I24)
664    | (Type::None, Type::None)) => Ok((t.0.clone(), Subst::default())),
665    (Type::Number(t1), Type::Number(t2)) => {
666      let (t, s) = unify(t1, t2)?;
667      Ok((Type::Number(Box::new(t)), s))
668    }
669    (Type::Number(tn), Type::Integer(ti)) | (Type::Integer(ti), Type::Number(tn)) => {
670      let (t, s) = unify(ti, tn)?;
671      Ok((Type::Integer(Box::new(t)), s))
672    }
673    (Type::Integer(t1), Type::Integer(t2)) => {
674      let (t, s) = unify(t1, t2)?;
675      Ok((Type::Integer(Box::new(t)), s))
676    }
677    (Type::Number(t1) | Type::Integer(t1), t2 @ (Type::U24 | Type::I24 | Type::F24))
678    | (t2 @ (Type::U24 | Type::I24 | Type::F24), Type::Number(t1) | Type::Integer(t1)) => {
679      let (t, s) = unify(t1, t2)?;
680      Ok((t, s))
681    }
682
683    (Type::Any, t) | (t, Type::Any) => {
684      let mut s = Subst::default();
685      // Recurse to assign variables to `Any` as well
686      for child in t.children() {
687        let (_, s2) = unify(&Type::Any, child)?;
688        s = s2.compose(s);
689      }
690      Ok((Type::Any, s))
691    }
692
693    _ => Err(String::new()),
694  })
695}
696
697/// Specializes the inferred type against the type annotation.
698/// This way, the annotation can be less general than the inferred type.
699///
700/// It also forces inferred 'Any' to the annotated, inferred types to
701/// annotated 'Any' and fills 'Hole' with the inferred type.
702///
703/// Errors if the first type is not a superset of the second type.
704fn specialize(inf: &Type, ann: &Type) -> Result<Type, String> {
705  fn merge_specialization(inf: &Type, exp: &Type, s: &mut Subst) -> Result<Type, String> {
706    maybe_grow(|| match (inf, exp) {
707      // These rules have to come before
708      (t, Type::Hole) => Ok(t.clone()),
709      (Type::Hole, _) => unreachable!("Hole should never appear in the inferred type"),
710
711      (_inf, Type::Any) => Ok(Type::Any),
712      (Type::Any, exp) => Ok(exp.clone()),
713
714      (Type::Var(x), new) => {
715        if let Some(old) = s.0.get(x) {
716          if old == new {
717            Ok(new.clone())
718          } else {
719            Err(format!(" Inferred type variable '{x}' must be both '{old}' and '{new}'"))
720          }
721        } else {
722          s.0.insert(x.clone(), new.clone());
723          Ok(new.clone())
724        }
725      }
726
727      (Type::Arr(l1, r1), Type::Arr(l2, r2)) => {
728        let l = merge_specialization(l1, l2, s)?;
729        let r = merge_specialization(r1, r2, s)?;
730        Ok(Type::Arr(Box::new(l), Box::new(r)))
731      }
732      (Type::Ctr(name1, ts1), Type::Ctr(name2, ts2)) if name1 == name2 && ts1.len() == ts2.len() => {
733        let mut ts = vec![];
734        for (t1, t2) in ts1.iter().zip(ts2.iter()) {
735          let t = merge_specialization(t1, t2, s)?;
736          ts.push(t);
737        }
738        Ok(Type::Ctr(name1.clone(), ts))
739      }
740      (Type::Tup(ts1), Type::Tup(ts2)) if ts1.len() == ts2.len() => {
741        let mut ts = vec![];
742        for (t1, t2) in ts1.iter().zip(ts2.iter()) {
743          let t = merge_specialization(t1, t2, s)?;
744          ts.push(t);
745        }
746        Ok(Type::Tup(ts))
747      }
748      (Type::Number(t1), Type::Number(t2)) => Ok(Type::Number(Box::new(merge_specialization(t1, t2, s)?))),
749      (Type::Integer(t1), Type::Integer(t2)) => Ok(Type::Integer(Box::new(merge_specialization(t1, t2, s)?))),
750      (Type::U24, Type::U24) | (Type::F24, Type::F24) | (Type::I24, Type::I24) | (Type::None, Type::None) => {
751        Ok(inf.clone())
752      }
753      _ => Err(String::new()),
754    })
755  }
756
757  // Refresh the variable names to avoid conflicts when unifying
758  // Names of type vars in the annotation have nothing to do with names in the inferred type.
759  let var_gen = &mut VarGen::default();
760  let inf2 = inf.generalize(&TypeEnv::default()).instantiate(var_gen);
761  let ann2 = ann.generalize(&TypeEnv::default()).instantiate(var_gen);
762
763  let (t, s) = unify(&inf2, &ann2)
764    .map_err(|e| format!("Type Error: Expected function type '{ann}' but found '{inf}'.{e}"))?;
765  let t = t.subst(&s);
766
767  // Merge the inferred specialization with the expected type.
768  // This is done to cast to/from `Any` and `_` types.
769  let mut merge_s = Subst::default();
770  let t2 = merge_specialization(&t, ann, &mut merge_s).map_err(|e| {
771    format!("Type Error: Annotated type '{ann}' is not a subtype of inferred type '{inf2}'.{e}")
772  })?;
773
774  Ok(t2.subst(&merge_s))
775}
776
777impl std::fmt::Display for Subst {
778  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
779    writeln!(f, "Subst {{")?;
780    for (x, y) in &self.0 {
781      writeln!(f, "  {x} => {y},")?;
782    }
783    write!(f, "}}")
784  }
785}