Skip to main content

rex_typesystem/
inference.rs

1use crate::{
2    error::TypeError,
3    types::{
4        AdtDecl, AdtVariant, BuiltinTypeId, Predicate, Scheme, Type, TypeConst, TypeEnv, TypeKind,
5        TypeVar, TypeVarId, TypedExpr, TypedExprKind, Types,
6    },
7    typesystem::{
8        TypeSystem, TypeVarSupply, instantiate, is_integral_literal_expr,
9        predicates_from_constraints, reject_ambiguous_scheme, type_from_annotation_expr,
10        type_from_annotation_expr_vars,
11    },
12    unification::{Subst, Unifier, compose_subst, subst_is_empty, unify},
13};
14use rex_ast::expr::{Expr, Pattern, Symbol, TypeConstraint, TypeExpr, sym};
15use rex_util::gas::GasMeter;
16use std::{
17    collections::{BTreeMap, BTreeSet},
18    sync::Arc,
19};
20
21fn dedup_preds(preds: Vec<Predicate>) -> Vec<Predicate> {
22    let mut seen = BTreeSet::new();
23    let mut out = Vec::with_capacity(preds.len());
24    for pred in preds {
25        if seen.insert(pred.clone()) {
26            out.push(pred);
27        }
28    }
29    out
30}
31
32fn is_integral_primitive(typ: &Type) -> bool {
33    matches!(
34        typ.as_ref(),
35        TypeKind::Con(TypeConst {
36            builtin_id: Some(
37                BuiltinTypeId::U8
38                    | BuiltinTypeId::U16
39                    | BuiltinTypeId::U32
40                    | BuiltinTypeId::U64
41                    | BuiltinTypeId::I8
42                    | BuiltinTypeId::I16
43                    | BuiltinTypeId::I32
44                    | BuiltinTypeId::I64
45            ),
46            ..
47        })
48    )
49}
50
51fn finalize_infer_for_public_api(
52    mut preds: Vec<Predicate>,
53    mut typ: Type,
54) -> Result<(Vec<Predicate>, Type), TypeError> {
55    let mut subst = Subst::new_sync();
56    for pred in &preds {
57        if pred.class.as_ref() == "Integral"
58            && let TypeKind::Var(tv) = pred.typ.as_ref()
59        {
60            subst = subst.insert(tv.id, Type::builtin(BuiltinTypeId::I32));
61        }
62    }
63
64    if !subst_is_empty(&subst) {
65        preds = dedup_preds(preds.apply(&subst));
66        typ = typ.apply(&subst);
67    }
68
69    for pred in &preds {
70        if pred.class.as_ref() != "Integral" {
71            continue;
72        }
73        if matches!(pred.typ.as_ref(), TypeKind::Var(_)) || is_integral_primitive(&pred.typ) {
74            continue;
75        }
76        return Err(TypeError::Unification("i32".into(), pred.typ.to_string()));
77    }
78
79    Ok((preds, typ))
80}
81
82#[derive(Clone, Debug)]
83struct KnownVariant {
84    adt: Symbol,
85    variant: Symbol,
86}
87
88type KnownVariants = BTreeMap<Symbol, KnownVariant>;
89
90fn apply_scheme_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> Scheme {
91    let preds = scheme
92        .preds
93        .iter()
94        .map(|pred| Predicate::new(pred.class.clone(), unifier.apply_type(&pred.typ)))
95        .collect();
96    let typ = unifier.apply_type(&scheme.typ);
97    Scheme::new(scheme.vars.clone(), preds, typ)
98}
99
100fn scheme_ftv_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> BTreeSet<TypeVarId> {
101    let mut ftv = unifier.apply_type(&scheme.typ).ftv();
102    for pred in &scheme.preds {
103        ftv.extend(unifier.apply_type(&pred.typ).ftv());
104    }
105    for var in &scheme.vars {
106        ftv.remove(&var.id);
107    }
108    ftv
109}
110
111fn env_ftv_with_unifier(env: &TypeEnv, unifier: &mut Unifier<'_>) -> BTreeSet<TypeVarId> {
112    let mut out = BTreeSet::new();
113    for (_name, schemes) in env.values.iter() {
114        for scheme in schemes {
115            out.extend(scheme_ftv_with_unifier(scheme, unifier));
116        }
117    }
118    out
119}
120
121fn generalize_with_unifier(
122    env: &TypeEnv,
123    preds: Vec<Predicate>,
124    typ: Type,
125    unifier: &mut Unifier<'_>,
126) -> Scheme {
127    let preds: Vec<Predicate> = preds
128        .into_iter()
129        .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
130        .collect();
131    let typ = unifier.apply_type(&typ);
132    let mut vars: Vec<TypeVar> = typ
133        .ftv()
134        .union(&preds.ftv())
135        .copied()
136        .collect::<BTreeSet<_>>()
137        .difference(&env_ftv_with_unifier(env, unifier))
138        .cloned()
139        .map(|id| TypeVar::new(id, None))
140        .collect();
141    vars.sort_by_key(|v| v.id);
142    Scheme::new(vars, preds, typ)
143}
144
145fn monomorphic_scheme_with_unifier(
146    preds: Vec<Predicate>,
147    typ: Type,
148    unifier: &mut Unifier<'_>,
149) -> Scheme {
150    let preds = dedup_preds(
151        preds
152            .into_iter()
153            .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
154            .collect(),
155    );
156    let typ = unifier.apply_type(&typ);
157    Scheme::new(vec![], preds, typ)
158}
159
160pub fn infer_typed(
161    type_system: &mut TypeSystem,
162    expr: &Expr,
163) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
164    infer_typed_inner(type_system, expr)
165}
166
167pub fn infer_typed_with_gas(
168    type_system: &mut TypeSystem,
169    expr: &Expr,
170    gas: &mut GasMeter,
171) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
172    let known = KnownVariants::new();
173    let mut unifier = Unifier::with_gas(gas, type_system.limits.max_infer_depth);
174    let (preds, t, typed) = infer_expr(
175        &mut unifier,
176        &mut type_system.supply,
177        &type_system.env,
178        &type_system.adts,
179        &known,
180        expr,
181    )
182    .map_err(|err| err.with_span(expr.span()))?;
183    let subst = unifier.into_subst();
184    let mut typed = typed.apply(&subst);
185    let mut preds = dedup_preds(preds.apply(&subst));
186    let mut t = t.apply(&subst);
187    let improve = improve_indexable(&preds)?;
188    if !subst_is_empty(&improve) {
189        typed = typed.apply(&improve);
190        preds = dedup_preds(preds.apply(&improve));
191        t = t.apply(&improve);
192    }
193    type_system.check_predicate_kinds(&preds)?;
194    Ok((typed, preds, t))
195}
196
197fn infer_typed_inner(
198    type_system: &mut TypeSystem,
199    expr: &Expr,
200) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
201    let known = KnownVariants::new();
202    let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
203    let (preds, t, typed) = infer_expr(
204        &mut unifier,
205        &mut type_system.supply,
206        &type_system.env,
207        &type_system.adts,
208        &known,
209        expr,
210    )
211    .map_err(|err| err.with_span(expr.span()))?;
212    let subst = unifier.into_subst();
213    let mut typed = typed.apply(&subst);
214    let mut preds = dedup_preds(preds.apply(&subst));
215    let mut t = t.apply(&subst);
216    let improve = improve_indexable(&preds)?;
217    if !subst_is_empty(&improve) {
218        typed = typed.apply(&improve);
219        preds = dedup_preds(preds.apply(&improve));
220        t = t.apply(&improve);
221    }
222    type_system.check_predicate_kinds(&preds)?;
223    Ok((typed, preds, t))
224}
225
226pub fn infer(
227    type_system: &mut TypeSystem,
228    expr: &Expr,
229) -> Result<(Vec<Predicate>, Type), TypeError> {
230    infer_inner(type_system, expr)
231}
232
233pub fn infer_with_gas(
234    type_system: &mut TypeSystem,
235    expr: &Expr,
236    gas: &mut GasMeter,
237) -> Result<(Vec<Predicate>, Type), TypeError> {
238    let known = KnownVariants::new();
239    let mut unifier = Unifier::with_gas(gas, type_system.limits.max_infer_depth);
240    let (preds, t) = infer_expr_type(
241        &mut unifier,
242        &mut type_system.supply,
243        &type_system.env,
244        &type_system.adts,
245        &known,
246        expr,
247    )
248    .map_err(|err| err.with_span(expr.span()))?;
249    let subst = unifier.into_subst();
250    let preds = dedup_preds(preds.apply(&subst));
251    let t = t.apply(&subst);
252    type_system.check_predicate_kinds(&preds)?;
253    finalize_infer_for_public_api(preds, t)
254}
255
256fn infer_inner(
257    type_system: &mut TypeSystem,
258    expr: &Expr,
259) -> Result<(Vec<Predicate>, Type), TypeError> {
260    let known = KnownVariants::new();
261    let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
262    let (preds, t) = infer_expr_type(
263        &mut unifier,
264        &mut type_system.supply,
265        &type_system.env,
266        &type_system.adts,
267        &known,
268        expr,
269    )
270    .map_err(|err| err.with_span(expr.span()))?;
271    let subst = unifier.into_subst();
272    let mut preds = dedup_preds(preds.apply(&subst));
273    let mut t = t.apply(&subst);
274    let improve = improve_indexable(&preds)?;
275    if !subst_is_empty(&improve) {
276        preds = dedup_preds(preds.apply(&improve));
277        t = t.apply(&improve);
278    }
279    type_system.check_predicate_kinds(&preds)?;
280    finalize_infer_for_public_api(preds, t)
281}
282
283fn improve_indexable(preds: &[Predicate]) -> Result<Subst, TypeError> {
284    let mut subst = Subst::new_sync();
285    loop {
286        let mut changed = false;
287        for pred in preds {
288            let pred = pred.apply(&subst);
289            if pred.class.as_ref() != "Indexable" {
290                continue;
291            }
292            let TypeKind::Tuple(parts) = pred.typ.as_ref() else {
293                continue;
294            };
295            if parts.len() != 2 {
296                continue;
297            }
298            let container = parts[0].clone();
299            let elem = parts[1].clone();
300            let s = indexable_elem_subst(&container, &elem)?;
301            if !subst_is_empty(&s) {
302                subst = compose_subst(s, subst);
303                changed = true;
304            }
305        }
306        if !changed {
307            break;
308        }
309    }
310    Ok(subst)
311}
312
313fn indexable_elem_subst(container: &Type, elem: &Type) -> Result<Subst, TypeError> {
314    match container.as_ref() {
315        TypeKind::App(head, arg) => match head.as_ref() {
316            TypeKind::Con(tc)
317                if matches!(
318                    tc.builtin_id,
319                    Some(BuiltinTypeId::List | BuiltinTypeId::Array)
320                ) =>
321            {
322                unify(elem, arg)
323            }
324            _ => Ok(Subst::new_sync()),
325        },
326        TypeKind::Tuple(elems) => {
327            if elems.is_empty() {
328                return Ok(Subst::new_sync());
329            }
330            let mut subst = Subst::new_sync();
331            let mut cur = elems[0].clone();
332            for ty in elems.iter().skip(1) {
333                let s_next = unify(&cur.apply(&subst), &ty.apply(&subst))?;
334                subst = compose_subst(s_next, subst);
335                cur = cur.apply(&subst);
336            }
337            let elem = elem.apply(&subst);
338            let s_elem = unify(&elem, &cur.apply(&subst))?;
339            Ok(compose_subst(s_elem, subst))
340        }
341        _ => Ok(Subst::new_sync()),
342    }
343}
344
345type LambdaChain<'a> = (
346    Vec<(Symbol, Option<TypeExpr>)>,
347    Vec<TypeConstraint>,
348    &'a Expr,
349);
350
351fn collect_lambda_chain<'a>(expr: &'a Expr) -> LambdaChain<'a> {
352    let mut params = Vec::new();
353    let mut constraints = Vec::new();
354    let mut cur = expr;
355    let mut seen_constraints = false;
356    while let Expr::Lam(_, _scope, param, ann, lam_constraints, body) = cur {
357        if !lam_constraints.is_empty() {
358            if seen_constraints {
359                break;
360            }
361            constraints = lam_constraints.clone();
362            seen_constraints = true;
363        }
364        params.push((param.name.clone(), ann.clone()));
365        cur = body.as_ref();
366    }
367    (params, constraints, cur)
368}
369
370fn collect_app_chain(expr: &Expr) -> (&Expr, Vec<&Expr>) {
371    let mut args = Vec::new();
372    let mut cur = expr;
373    while let Expr::App(_, f, x) = cur {
374        args.push(x.as_ref());
375        cur = f.as_ref();
376    }
377    args.reverse();
378    (cur, args)
379}
380
381fn narrow_overload_candidates(candidates: &[Type], arg_ty: &Type) -> Vec<Type> {
382    let mut out = Vec::new();
383    for candidate in candidates {
384        let Some((params, ret)) = decompose_fun(candidate, 1) else {
385            continue;
386        };
387        let param = &params[0];
388        if let Ok(s) = unify(param, arg_ty) {
389            out.push(ret.apply(&s));
390        }
391    }
392    out
393}
394
395fn unary_app_arg(typ: &Type, ctor_name: &str) -> Option<Type> {
396    let TypeKind::App(head, arg) = typ.as_ref() else {
397        return None;
398    };
399    let TypeKind::Con(tc) = head.as_ref() else {
400        return None;
401    };
402    (tc.name.as_ref() == ctor_name && tc.arity == 1).then(|| arg.clone())
403}
404
405fn infer_app_arg_type(
406    unifier: &mut Unifier<'_>,
407    supply: &mut TypeVarSupply,
408    env: &TypeEnv,
409    adts: &BTreeMap<Symbol, AdtDecl>,
410    known: &KnownVariants,
411    arg_hint: Option<Type>,
412    arg: &Expr,
413) -> Result<(Vec<Predicate>, Type), TypeError> {
414    match (arg_hint, arg) {
415        (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
416            infer_record_update_type_with_hint(
417                unifier,
418                supply,
419                env,
420                adts,
421                known,
422                base.as_ref(),
423                updates,
424                &arg_hint,
425            )
426        }
427        (Some(arg_hint), Expr::Dict(_, kvs))
428            if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
429        {
430            let TypeKind::Record(fields) = arg_hint.as_ref() else {
431                unreachable!("guarded by matches!")
432            };
433            let expected: BTreeMap<_, _> =
434                fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
435            let mut seen = BTreeSet::new();
436            let mut preds = Vec::new();
437            for (k, v) in kvs {
438                let expected_ty = expected
439                    .get(k)
440                    .ok_or_else(|| TypeError::UnknownField {
441                        field: k.clone(),
442                        typ: Type::record(fields.clone()).to_string(),
443                    })?
444                    .clone();
445                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
446                unifier.unify(&t1, &expected_ty)?;
447                preds.extend(p1);
448                seen.insert(k.clone());
449            }
450            for key in expected.keys() {
451                if !seen.contains(key.as_ref()) {
452                    return Err(TypeError::UnknownField {
453                        field: key.clone(),
454                        typ: Type::record(fields.clone()).to_string(),
455                    });
456                }
457            }
458            let record_ty = Type::record(
459                fields
460                    .iter()
461                    .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
462                    .collect(),
463            );
464            Ok((preds, record_ty))
465        }
466        _ => infer_expr_type(unifier, supply, env, adts, known, arg),
467    }
468}
469
470fn infer_app_arg_typed(
471    unifier: &mut Unifier<'_>,
472    supply: &mut TypeVarSupply,
473    env: &TypeEnv,
474    adts: &BTreeMap<Symbol, AdtDecl>,
475    known: &KnownVariants,
476    arg_hint: Option<Type>,
477    arg: &Expr,
478) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
479    match (arg_hint, arg) {
480        (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
481            infer_record_update_typed_with_hint(
482                unifier,
483                supply,
484                env,
485                adts,
486                known,
487                base.as_ref(),
488                updates,
489                &arg_hint,
490            )
491        }
492        (Some(arg_hint), Expr::Dict(_, kvs))
493            if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
494        {
495            let TypeKind::Record(fields) = arg_hint.as_ref() else {
496                unreachable!("guarded by matches!")
497            };
498            let mut preds = Vec::new();
499            let mut typed_kvs = BTreeMap::new();
500            let expected: BTreeMap<_, _> =
501                fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
502            for (k, v) in kvs {
503                let expected_ty = expected
504                    .get(k)
505                    .ok_or_else(|| TypeError::UnknownField {
506                        field: k.clone(),
507                        typ: Type::record(fields.clone()).to_string(),
508                    })?
509                    .clone();
510                let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
511                unifier.unify(&t1, &expected_ty)?;
512                preds.extend(p1);
513                typed_kvs.insert(k.clone(), typed_v);
514            }
515            for key in expected.keys() {
516                if !typed_kvs.contains_key(key.as_ref()) {
517                    return Err(TypeError::UnknownField {
518                        field: key.clone(),
519                        typ: Type::record(fields.clone()).to_string(),
520                    });
521                }
522            }
523            let record_ty = Type::record(
524                fields
525                    .iter()
526                    .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
527                    .collect(),
528            );
529            let typed = TypedExpr::new(record_ty.clone(), TypedExprKind::Dict(typed_kvs));
530            Ok((preds, record_ty, typed))
531        }
532        _ => infer_expr(unifier, supply, env, adts, known, arg),
533    }
534}
535
536#[allow(clippy::too_many_arguments)]
537fn infer_record_update_type_with_hint(
538    unifier: &mut Unifier<'_>,
539    supply: &mut TypeVarSupply,
540    env: &TypeEnv,
541    adts: &BTreeMap<Symbol, AdtDecl>,
542    known: &KnownVariants,
543    base: &Expr,
544    updates: &BTreeMap<Symbol, Arc<Expr>>,
545    hint_ty: &Type,
546) -> Result<(Vec<Predicate>, Type), TypeError> {
547    let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
548    unifier.unify(&t_base, hint_ty)?;
549    let base_ty = unifier.apply_type(&t_base);
550    let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
551    let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
552    let (result_ty, fields) = resolve_record_update(
553        unifier,
554        supply,
555        adts,
556        &base_ty,
557        known_variant,
558        &update_fields,
559    )?;
560    let expected: BTreeMap<_, _> = fields.into_iter().collect();
561
562    let mut preds = p_base;
563    for (k, v) in updates {
564        let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
565            field: k.clone(),
566            typ: result_ty.to_string(),
567        })?;
568        let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
569        unifier.unify(&t1, expected_ty)?;
570        preds.extend(p1);
571    }
572    Ok((preds, result_ty))
573}
574
575#[allow(clippy::too_many_arguments)]
576fn infer_record_update_typed_with_hint(
577    unifier: &mut Unifier<'_>,
578    supply: &mut TypeVarSupply,
579    env: &TypeEnv,
580    adts: &BTreeMap<Symbol, AdtDecl>,
581    known: &KnownVariants,
582    base: &Expr,
583    updates: &BTreeMap<Symbol, Arc<Expr>>,
584    hint_ty: &Type,
585) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
586    let (p_base, t_base, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
587    unifier.unify(&t_base, hint_ty)?;
588    let base_ty = unifier.apply_type(&t_base);
589    let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
590    let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
591    let (result_ty, fields) = resolve_record_update(
592        unifier,
593        supply,
594        adts,
595        &base_ty,
596        known_variant,
597        &update_fields,
598    )?;
599    let expected: BTreeMap<_, _> = fields.into_iter().collect();
600
601    let mut preds = p_base;
602    let mut typed_updates = BTreeMap::new();
603    for (k, v) in updates {
604        let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
605            field: k.clone(),
606            typ: result_ty.to_string(),
607        })?;
608        let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
609        unifier.unify(&t1, expected_ty)?;
610        preds.extend(p1);
611        typed_updates.insert(k.clone(), typed_v);
612    }
613
614    let typed = TypedExpr::new(
615        result_ty.clone(),
616        TypedExprKind::RecordUpdate {
617            base: Box::new(typed_base),
618            updates: typed_updates,
619        },
620    );
621    Ok((preds, result_ty, typed))
622}
623
624fn infer_expr_type(
625    unifier: &mut Unifier<'_>,
626    supply: &mut TypeVarSupply,
627    env: &TypeEnv,
628    adts: &BTreeMap<Symbol, AdtDecl>,
629    known: &KnownVariants,
630    expr: &Expr,
631) -> Result<(Vec<Predicate>, Type), TypeError> {
632    let span = *expr.span();
633    let res = unifier.with_infer_depth(span, |unifier| {
634        infer_expr_type_inner(unifier, supply, env, adts, known, expr)
635    });
636    res.map_err(|err| err.with_span(&span))
637}
638
639fn infer_expr_type_inner(
640    unifier: &mut Unifier<'_>,
641    supply: &mut TypeVarSupply,
642    env: &TypeEnv,
643    adts: &BTreeMap<Symbol, AdtDecl>,
644    known: &KnownVariants,
645    expr: &Expr,
646) -> Result<(Vec<Predicate>, Type), TypeError> {
647    unifier.charge_infer_node()?;
648    match expr {
649        Expr::Bool(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Bool))),
650        Expr::Uint(_, _) => {
651            let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
652            Ok((vec![Predicate::new("Integral", lit_ty.clone())], lit_ty))
653        }
654        Expr::Int(_, _) => {
655            let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
656            Ok((
657                vec![
658                    Predicate::new("Integral", lit_ty.clone()),
659                    Predicate::new("AdditiveGroup", lit_ty.clone()),
660                ],
661                lit_ty,
662            ))
663        }
664        Expr::Float(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::F32))),
665        Expr::String(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::String))),
666        Expr::Uuid(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Uuid))),
667        Expr::DateTime(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::DateTime))),
668        Expr::Hole(_) => {
669            let t = Type::var(supply.fresh(Some(sym("hole"))));
670            Ok((vec![], t))
671        }
672        Expr::Var(var) => {
673            let schemes = env
674                .lookup(&var.name)
675                .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
676            if schemes.len() == 1 {
677                let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
678                let (preds, t) = instantiate(&scheme, supply);
679                Ok((preds, t))
680            } else {
681                for scheme in schemes {
682                    if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
683                        return Err(TypeError::AmbiguousOverload(var.name.clone()));
684                    }
685                }
686                let t = Type::var(supply.fresh(Some(var.name.clone())));
687                Ok((vec![], t))
688            }
689        }
690        Expr::Lam(..) => {
691            let (params, constraints, body) = collect_lambda_chain(expr);
692            let mut ann_vars = BTreeMap::new();
693            let mut param_tys = Vec::with_capacity(params.len());
694            for (name, ann) in &params {
695                let param_ty = match ann {
696                    Some(ann) => type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?,
697                    None => Type::var(supply.fresh(Some(name.clone()))),
698                };
699                param_tys.push((name.clone(), param_ty));
700            }
701
702            let mut env1 = env.clone();
703            let mut known_body = known.clone();
704            for (name, param_ty) in &param_tys {
705                env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
706                known_body.remove(name);
707            }
708
709            let (mut preds, body_ty) =
710                infer_expr_type(unifier, supply, &env1, adts, &known_body, body)?;
711            let constraint_preds =
712                predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
713            preds.extend(constraint_preds);
714
715            let mut fun_ty = unifier.apply_type(&body_ty);
716            for (_, param_ty) in param_tys.iter().rev() {
717                fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
718            }
719            Ok((preds, fun_ty))
720        }
721        Expr::App(..) => {
722            let (head, args) = collect_app_chain(expr);
723            let (mut preds, mut func_ty) =
724                infer_expr_type(unifier, supply, env, adts, known, head)?;
725            let mut overload_name = None;
726            let mut overload_candidates = if let Expr::Var(var) = head {
727                if let Some(schemes) = env.lookup(&var.name) {
728                    if schemes.len() <= 1 {
729                        None
730                    } else {
731                        let mut candidates = Vec::new();
732                        for scheme in schemes {
733                            if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
734                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
735                            }
736                            let scheme = apply_scheme_with_unifier(scheme, unifier);
737                            let (p, typ) = instantiate(&scheme, supply);
738                            if !p.is_empty() {
739                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
740                            }
741                            candidates.push(typ);
742                        }
743                        overload_name = Some(var.name.clone());
744                        Some(candidates)
745                    }
746                } else {
747                    None
748                }
749            } else {
750                None
751            };
752            for arg in args {
753                let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
754                    TypeKind::Fun(arg, _) => Some(arg.clone()),
755                    _ => None,
756                };
757                let (p_arg, arg_ty) =
758                    infer_app_arg_type(unifier, supply, env, adts, known, arg_hint, arg)?;
759                let arg_ty = unifier.apply_type(&arg_ty);
760                if let Some(candidates) = overload_candidates.take() {
761                    let candidates = candidates
762                        .into_iter()
763                        .map(|t| unifier.apply_type(&t))
764                        .collect::<Vec<_>>();
765                    let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
766                    if narrowed.is_empty()
767                        && let Some(name) = &overload_name
768                    {
769                        return Err(TypeError::AmbiguousOverload(name.clone()));
770                    }
771                    overload_candidates = Some(narrowed);
772                }
773                let res_ty = match overload_candidates.as_ref() {
774                    Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
775                    _ => Type::var(supply.fresh(Some("r".into()))),
776                };
777                unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
778                preds.extend(p_arg);
779                func_ty = match overload_candidates.as_ref() {
780                    Some(candidates) if candidates.len() == 1 => unifier.apply_type(&candidates[0]),
781                    _ => unifier.apply_type(&res_ty),
782                };
783            }
784            Ok((preds, func_ty))
785        }
786        Expr::Project(_, base, field) => {
787            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, base)?;
788            let base_ty = unifier.apply_type(&t1);
789            let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
790            let field_ty =
791                resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
792            Ok((p1, field_ty))
793        }
794        Expr::RecordUpdate(_, base, updates) => {
795            let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
796            let base_ty = unifier.apply_type(&t_base);
797            let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
798            let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
799            let (result_ty, fields) = resolve_record_update(
800                unifier,
801                supply,
802                adts,
803                &base_ty,
804                known_variant,
805                &update_fields,
806            )?;
807            let expected: BTreeMap<_, _> = fields.into_iter().collect();
808
809            let mut preds = p_base;
810            for (k, v) in updates {
811                let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
812                    field: k.clone(),
813                    typ: result_ty.to_string(),
814                })?;
815                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
816                unifier.unify(&t1, expected_ty)?;
817                preds.extend(p1);
818            }
819            Ok((preds, result_ty))
820        }
821        Expr::Let(..) => {
822            let mut bindings = Vec::new();
823            let mut cur = expr;
824            while let Expr::Let(_, v, ann, d, b) = cur {
825                bindings.push((v.clone(), ann.clone(), d.clone()));
826                cur = b.as_ref();
827            }
828
829            let mut env_cur = env.clone();
830            let mut known_cur = known.clone();
831            for (v, ann, d) in bindings {
832                let (p1, t1) = if let Some(ref ann_expr) = ann {
833                    let mut ann_vars = BTreeMap::new();
834                    let ann_ty =
835                        type_from_annotation_expr_vars(adts, ann_expr, &mut ann_vars, supply)?;
836                    match d.as_ref() {
837                        Expr::RecordUpdate(_, base, updates) => infer_record_update_type_with_hint(
838                            unifier,
839                            supply,
840                            &env_cur,
841                            adts,
842                            &known_cur,
843                            base.as_ref(),
844                            updates,
845                            &ann_ty,
846                        )?,
847                        _ => {
848                            let (p1, t1) =
849                                infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?;
850                            unifier.unify(&t1, &ann_ty)?;
851                            (p1, t1)
852                        }
853                    }
854                } else {
855                    infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?
856                };
857                let def_ty = unifier.apply_type(&t1);
858                let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
859                    monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
860                } else {
861                    let scheme = generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
862                    reject_ambiguous_scheme(&scheme)?;
863                    scheme
864                };
865                env_cur.extend(v.name.clone(), scheme);
866                if let Some(known_variant) =
867                    known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
868                {
869                    known_cur.insert(
870                        v.name.clone(),
871                        KnownVariant {
872                            adt: known_variant.adt,
873                            variant: known_variant.variant,
874                        },
875                    );
876                } else {
877                    known_cur.remove(&v.name);
878                }
879            }
880
881            let (p_body, t_body) =
882                infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, cur)?;
883            Ok((p_body, t_body))
884        }
885        Expr::LetRec(_, bindings, body) => {
886            let mut env_seed = env.clone();
887            let mut known_seed = known.clone();
888            let mut binding_tys = BTreeMap::new();
889            for (var, _ann, _def) in bindings {
890                let tv = Type::var(supply.fresh(Some(var.name.clone())));
891                env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
892                known_seed.remove(&var.name);
893                binding_tys.insert(var.name.clone(), tv);
894            }
895
896            let mut inferred = Vec::with_capacity(bindings.len());
897            for (var, ann, def) in bindings {
898                let (preds, def_ty) =
899                    infer_expr_type(unifier, supply, &env_seed, adts, &known_seed, def)?;
900                if let Some(ann) = ann {
901                    let mut ann_vars = BTreeMap::new();
902                    let ann_ty = type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
903                    unifier.unify(&def_ty, &ann_ty)?;
904                }
905                let binding_ty = binding_tys
906                    .get(&var.name)
907                    .cloned()
908                    .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
909                unifier.unify(&binding_ty, &def_ty)?;
910                let resolved_ty = unifier.apply_type(&binding_ty);
911
912                if let Some(known_variant) =
913                    known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
914                {
915                    known_seed.insert(
916                        var.name.clone(),
917                        KnownVariant {
918                            adt: known_variant.adt,
919                            variant: known_variant.variant,
920                        },
921                    );
922                } else {
923                    known_seed.remove(&var.name);
924                }
925                inferred.push((var.name.clone(), preds, resolved_ty));
926            }
927
928            let mut env_body = env.clone();
929            for (name, preds, def_ty) in inferred {
930                let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
931                reject_ambiguous_scheme(&scheme)?;
932                env_body.extend(name, scheme);
933            }
934
935            let (p_body, t_body) =
936                infer_expr_type(unifier, supply, &env_body, adts, &known_seed, body)?;
937            Ok((p_body, t_body))
938        }
939        Expr::Ite(_, cond, then_expr, else_expr) => {
940            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, cond)?;
941            unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
942            let (p2, t2) = infer_expr_type(unifier, supply, env, adts, known, then_expr)?;
943            let (p3, t3) = infer_expr_type(unifier, supply, env, adts, known, else_expr)?;
944            unifier.unify(&t2, &t3)?;
945            let out_ty = unifier.apply_type(&t2);
946            let mut preds = p1;
947            preds.extend(p2);
948            preds.extend(p3);
949            Ok((preds, out_ty))
950        }
951        Expr::Tuple(_, elems) => {
952            let mut preds = Vec::new();
953            let mut types = Vec::new();
954            for elem in elems {
955                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
956                preds.extend(p1);
957                types.push(unifier.apply_type(&t1));
958            }
959            let tuple_ty = Type::tuple(types);
960            Ok((preds, tuple_ty))
961        }
962        Expr::List(_, elems) => {
963            let elem_tv = Type::var(supply.fresh(Some("a".into())));
964            let mut preds = Vec::new();
965            for elem in elems {
966                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
967                unifier.unify(&t1, &elem_tv)?;
968                preds.extend(p1);
969            }
970            let list_ty = Type::app(
971                Type::builtin(BuiltinTypeId::List),
972                unifier.apply_type(&elem_tv),
973            );
974            Ok((preds, list_ty))
975        }
976        Expr::Dict(_, kvs) => {
977            let elem_tv = Type::var(supply.fresh(Some("v".into())));
978            let mut preds = Vec::new();
979            for v in kvs.values() {
980                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
981                unifier.unify(&t1, &elem_tv)?;
982                preds.extend(p1);
983            }
984            let dict_ty = Type::app(
985                Type::builtin(BuiltinTypeId::Dict),
986                unifier.apply_type(&elem_tv),
987            );
988            Ok((preds, dict_ty))
989        }
990        Expr::Match(_, scrutinee, arms) => {
991            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, scrutinee.as_ref())?;
992            let mut preds = p1;
993            let res_ty = Type::var(supply.fresh(Some("match".into())));
994            let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
995
996            for (pat, expr) in arms {
997                let scrutinee_ty = unifier.apply_type(&t1);
998                let (p_pat, binds) = infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
999                preds.extend(p_pat);
1000
1001                let mut env_arm = env.clone();
1002                for (name, ty) in binds {
1003                    env_arm.extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
1004                }
1005                let mut known_arm = known.clone();
1006                if let Expr::Var(var) = scrutinee.as_ref() {
1007                    match pat {
1008                        Pattern::Named(_, name, _) => {
1009                            let name_sym = name.to_dotted_symbol();
1010                            if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
1011                                known_arm.insert(
1012                                    var.name.clone(),
1013                                    KnownVariant {
1014                                        adt: adt.name.clone(),
1015                                        variant: name_sym,
1016                                    },
1017                                );
1018                            } else {
1019                                known_arm.remove(&var.name);
1020                            }
1021                        }
1022                        _ => {
1023                            known_arm.remove(&var.name);
1024                        }
1025                    }
1026                }
1027                let (p_expr, t_expr) =
1028                    infer_expr_type(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1029                unifier.unify(&res_ty, &t_expr)?;
1030                preds.extend(p_expr);
1031            }
1032
1033            let scrutinee_ty = unifier.apply_type(&t1);
1034            check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1035            let out_ty = unifier.apply_type(&res_ty);
1036            Ok((preds, out_ty))
1037        }
1038        Expr::Ann(_, expr, ann) => {
1039            let ann_ty = type_from_annotation_expr(adts, ann)?;
1040            match expr.as_ref() {
1041                Expr::RecordUpdate(_, base, updates) => {
1042                    let (preds, out_ty) = infer_record_update_type_with_hint(
1043                        unifier,
1044                        supply,
1045                        env,
1046                        adts,
1047                        known,
1048                        base.as_ref(),
1049                        updates,
1050                        &ann_ty,
1051                    )?;
1052                    Ok((preds, out_ty))
1053                }
1054                _ => {
1055                    let (preds, expr_ty) =
1056                        infer_expr_type(unifier, supply, env, adts, known, expr)?;
1057                    unifier.unify(&expr_ty, &ann_ty)?;
1058                    let out_ty = unifier.apply_type(&ann_ty);
1059                    Ok((preds, out_ty))
1060                }
1061            }
1062        }
1063    }
1064}
1065
1066fn infer_expr(
1067    unifier: &mut Unifier<'_>,
1068    supply: &mut TypeVarSupply,
1069    env: &TypeEnv,
1070    adts: &BTreeMap<Symbol, AdtDecl>,
1071    known: &KnownVariants,
1072    expr: &Expr,
1073) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
1074    let span = *expr.span();
1075    let res = unifier.with_infer_depth(span, |unifier| {
1076        (|| {
1077            unifier.charge_infer_node()?;
1078            match expr {
1079                Expr::Bool(_, v) => {
1080                    let t = Type::builtin(BuiltinTypeId::Bool);
1081                    Ok((
1082                        vec![],
1083                        t.clone(),
1084                        TypedExpr::new(t, TypedExprKind::Bool(*v)),
1085                    ))
1086                }
1087                Expr::Uint(_, v) => {
1088                    let t = Type::var(supply.fresh(Some(sym("n"))));
1089                    Ok((
1090                        vec![Predicate::new("Integral", t.clone())],
1091                        t.clone(),
1092                        TypedExpr::new(t, TypedExprKind::Uint(*v)),
1093                    ))
1094                }
1095                Expr::Int(_, v) => {
1096                    let t = Type::var(supply.fresh(Some(sym("n"))));
1097                    Ok((
1098                        vec![
1099                            Predicate::new("Integral", t.clone()),
1100                            Predicate::new("AdditiveGroup", t.clone()),
1101                        ],
1102                        t.clone(),
1103                        TypedExpr::new(t, TypedExprKind::Int(*v)),
1104                    ))
1105                }
1106                Expr::Float(_, v) => {
1107                    let t = Type::builtin(BuiltinTypeId::F32);
1108                    Ok((
1109                        vec![],
1110                        t.clone(),
1111                        TypedExpr::new(t, TypedExprKind::Float(*v)),
1112                    ))
1113                }
1114                Expr::String(_, v) => {
1115                    let t = Type::builtin(BuiltinTypeId::String);
1116                    Ok((
1117                        vec![],
1118                        t.clone(),
1119                        TypedExpr::new(t, TypedExprKind::String(v.clone())),
1120                    ))
1121                }
1122                Expr::Uuid(_, v) => {
1123                    let t = Type::builtin(BuiltinTypeId::Uuid);
1124                    Ok((
1125                        vec![],
1126                        t.clone(),
1127                        TypedExpr::new(t, TypedExprKind::Uuid(*v)),
1128                    ))
1129                }
1130                Expr::DateTime(_, v) => {
1131                    let t = Type::builtin(BuiltinTypeId::DateTime);
1132                    Ok((
1133                        vec![],
1134                        t.clone(),
1135                        TypedExpr::new(t, TypedExprKind::DateTime(*v)),
1136                    ))
1137                }
1138                Expr::Hole(_) => {
1139                    let t = Type::var(supply.fresh(Some(sym("hole"))));
1140                    Ok((vec![], t.clone(), TypedExpr::new(t, TypedExprKind::Hole)))
1141                }
1142                Expr::Var(var) => {
1143                    let schemes = env
1144                        .lookup(&var.name)
1145                        .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1146                    if schemes.len() == 1 {
1147                        let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
1148                        let (preds, t) = instantiate(&scheme, supply);
1149                        let typed = TypedExpr::new(
1150                            t.clone(),
1151                            TypedExprKind::Var {
1152                                name: var.name.clone(),
1153                                overloads: vec![],
1154                            },
1155                        );
1156                        Ok((preds, t, typed))
1157                    } else {
1158                        let mut overloads = Vec::new();
1159                        for scheme in schemes {
1160                            if !scheme.preds.is_empty() {
1161                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
1162                            }
1163
1164                            let scheme = apply_scheme_with_unifier(scheme, unifier);
1165                            let (preds, typ) = instantiate(&scheme, supply);
1166                            if !preds.is_empty() {
1167                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
1168                            }
1169                            overloads.push(typ);
1170                        }
1171                        let t = Type::var(supply.fresh(Some(var.name.clone())));
1172                        let typed = TypedExpr::new(
1173                            t.clone(),
1174                            TypedExprKind::Var {
1175                                name: var.name.clone(),
1176                                overloads,
1177                            },
1178                        );
1179                        Ok((vec![], t, typed))
1180                    }
1181                }
1182                Expr::Lam(..) => {
1183                    let (params, constraints, body) = collect_lambda_chain(expr);
1184                    let mut ann_vars = BTreeMap::new();
1185                    let mut param_tys = Vec::with_capacity(params.len());
1186                    for (name, ann) in &params {
1187                        let param_ty = match ann {
1188                            Some(ann) => {
1189                                type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?
1190                            }
1191                            None => Type::var(supply.fresh(Some(name.clone()))),
1192                        };
1193                        param_tys.push((name.clone(), param_ty));
1194                    }
1195
1196                    let mut env1 = env.clone();
1197                    let mut known_body = known.clone();
1198                    for (name, param_ty) in &param_tys {
1199                        env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
1200                        known_body.remove(name);
1201                    }
1202
1203                    let (mut preds, body_ty, typed_body) =
1204                        infer_expr(unifier, supply, &env1, adts, &known_body, body)?;
1205                    let constraint_preds =
1206                        predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
1207                    preds.extend(constraint_preds);
1208
1209                    let mut typed = typed_body;
1210                    let mut fun_ty = unifier.apply_type(&body_ty);
1211                    for (name, param_ty) in param_tys.iter().rev() {
1212                        fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
1213                        typed = TypedExpr::new(
1214                            fun_ty.clone(),
1215                            TypedExprKind::Lam {
1216                                param: name.clone(),
1217                                body: Box::new(typed),
1218                            },
1219                        );
1220                    }
1221
1222                    Ok((preds, fun_ty, typed))
1223                }
1224                Expr::App(..) => {
1225                    let (head, args) = collect_app_chain(expr);
1226                    let (mut preds, mut func_ty, mut typed) =
1227                        infer_expr(unifier, supply, env, adts, known, head)?;
1228                    let mut overload_name = None;
1229                    let mut overload_candidates = match &typed.kind {
1230                        TypedExprKind::Var { name, overloads } if !overloads.is_empty() => {
1231                            overload_name = Some(name.clone());
1232                            Some(overloads.clone())
1233                        }
1234                        _ => None,
1235                    };
1236                    for arg in args {
1237                        let expected_arg = match unifier.apply_type(&func_ty).as_ref() {
1238                            TypeKind::Fun(arg, _) => Some(arg.clone()),
1239                            _ => None,
1240                        };
1241                        let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
1242                            TypeKind::Fun(arg, _) => Some(arg.clone()),
1243                            _ => None,
1244                        };
1245                        let (p_arg, arg_ty, typed_arg) =
1246                            infer_app_arg_typed(unifier, supply, env, adts, known, arg_hint, arg)?;
1247                        let mut arg_ty = unifier.apply_type(&arg_ty);
1248                        let mut typed_arg = typed_arg;
1249
1250                        if let Some(expected_arg) = expected_arg {
1251                            let expected_arg = unifier.apply_type(&expected_arg);
1252                            if let (Some(expected_elem), Some(arg_elem)) = (
1253                                unary_app_arg(&expected_arg, "Array"),
1254                                unary_app_arg(&arg_ty, "List"),
1255                            ) {
1256                                unifier.unify(&expected_elem, &arg_elem)?;
1257                                let elem_ty = unifier.apply_type(&expected_elem);
1258                                let list_ty = Type::list(elem_ty.clone());
1259                                let array_ty = Type::array(elem_ty);
1260                                let coercion_ty = Type::fun(list_ty, array_ty.clone());
1261                                let coercion_fn = TypedExpr::new(
1262                                    coercion_ty,
1263                                    TypedExprKind::Var {
1264                                        name: sym("prim_array_from_list"),
1265                                        overloads: vec![],
1266                                    },
1267                                );
1268                                typed_arg = TypedExpr::new(
1269                                    array_ty.clone(),
1270                                    TypedExprKind::App(Box::new(coercion_fn), Box::new(typed_arg)),
1271                                );
1272                                arg_ty = array_ty;
1273                            }
1274                        }
1275                        if let Some(candidates) = overload_candidates.take() {
1276                            let candidates = candidates
1277                                .into_iter()
1278                                .map(|t| unifier.apply_type(&t))
1279                                .collect::<Vec<_>>();
1280                            let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
1281                            if narrowed.is_empty()
1282                                && let Some(name) = &overload_name
1283                            {
1284                                return Err(TypeError::AmbiguousOverload(name.clone()));
1285                            }
1286                            overload_candidates = Some(narrowed);
1287                        }
1288                        let res_ty = match overload_candidates.as_ref() {
1289                            Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
1290                            _ => Type::var(supply.fresh(Some("r".into()))),
1291                        };
1292                        unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
1293                        let result_ty = match overload_candidates.as_ref() {
1294                            Some(candidates) if candidates.len() == 1 => {
1295                                unifier.apply_type(&candidates[0])
1296                            }
1297                            _ => unifier.apply_type(&res_ty),
1298                        };
1299                        preds.extend(p_arg);
1300                        typed = TypedExpr::new(
1301                            result_ty.clone(),
1302                            TypedExprKind::App(Box::new(typed), Box::new(typed_arg)),
1303                        );
1304                        func_ty = result_ty;
1305                    }
1306                    Ok((preds, func_ty, typed))
1307                }
1308                Expr::Project(_, base, field) => {
1309                    let (p1, t1, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
1310                    let base_ty = unifier.apply_type(&t1);
1311                    let known_variant =
1312                        known_variant_from_expr_with_known(base, &base_ty, adts, known);
1313                    let field_ty =
1314                        resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
1315                    let typed = TypedExpr::new(
1316                        field_ty.clone(),
1317                        TypedExprKind::Project {
1318                            expr: Box::new(typed_base),
1319                            field: field.clone(),
1320                        },
1321                    );
1322                    Ok((p1, field_ty, typed))
1323                }
1324                Expr::RecordUpdate(_, base, updates) => {
1325                    let (p_base, t_base, typed_base) =
1326                        infer_expr(unifier, supply, env, adts, known, base)?;
1327                    let base_ty = unifier.apply_type(&t_base);
1328                    let known_variant =
1329                        known_variant_from_expr_with_known(base, &base_ty, adts, known);
1330                    let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
1331                    let (result_ty, fields) = resolve_record_update(
1332                        unifier,
1333                        supply,
1334                        adts,
1335                        &base_ty,
1336                        known_variant,
1337                        &update_fields,
1338                    )?;
1339                    let expected: BTreeMap<_, _> = fields.into_iter().collect();
1340
1341                    let mut preds = p_base;
1342                    let mut typed_updates = BTreeMap::new();
1343                    for (k, v) in updates {
1344                        let expected_ty =
1345                            expected.get(k).ok_or_else(|| TypeError::UnknownField {
1346                                field: k.clone(),
1347                                typ: result_ty.to_string(),
1348                            })?;
1349                        let (p1, t1, typed_v) =
1350                            infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
1351                        unifier.unify(&t1, expected_ty)?;
1352                        preds.extend(p1);
1353                        typed_updates.insert(k.clone(), typed_v);
1354                    }
1355                    let typed = TypedExpr::new(
1356                        result_ty.clone(),
1357                        TypedExprKind::RecordUpdate {
1358                            base: Box::new(typed_base),
1359                            updates: typed_updates,
1360                        },
1361                    );
1362                    Ok((preds, result_ty, typed))
1363                }
1364                Expr::Let(..) => {
1365                    let mut bindings = Vec::new();
1366                    let mut cur = expr;
1367                    while let Expr::Let(_, v, ann, d, b) = cur {
1368                        bindings.push((v.clone(), ann.clone(), d.clone()));
1369                        cur = b.as_ref();
1370                    }
1371
1372                    let mut env_cur = env.clone();
1373                    let mut known_cur = known.clone();
1374                    let mut typed_defs = Vec::new();
1375                    for (v, ann, d) in bindings {
1376                        let (p1, t1, typed_def) = if let Some(ref ann_expr) = ann {
1377                            let mut ann_vars = BTreeMap::new();
1378                            let ann_ty = type_from_annotation_expr_vars(
1379                                adts,
1380                                ann_expr,
1381                                &mut ann_vars,
1382                                supply,
1383                            )?;
1384                            match d.as_ref() {
1385                                Expr::RecordUpdate(_, base, updates) => {
1386                                    infer_record_update_typed_with_hint(
1387                                        unifier,
1388                                        supply,
1389                                        &env_cur,
1390                                        adts,
1391                                        &known_cur,
1392                                        base.as_ref(),
1393                                        updates,
1394                                        &ann_ty,
1395                                    )?
1396                                }
1397                                _ => {
1398                                    let (p1, t1, typed_def) = infer_expr(
1399                                        unifier, supply, &env_cur, adts, &known_cur, &d,
1400                                    )?;
1401                                    unifier.unify(&t1, &ann_ty)?;
1402                                    (p1, t1, typed_def)
1403                                }
1404                            }
1405                        } else {
1406                            infer_expr(unifier, supply, &env_cur, adts, &known_cur, &d)?
1407                        };
1408                        let def_ty = unifier.apply_type(&t1);
1409                        let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
1410                            monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
1411                        } else {
1412                            let scheme =
1413                                generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
1414                            reject_ambiguous_scheme(&scheme)?;
1415                            scheme
1416                        };
1417                        env_cur.extend(v.name.clone(), scheme);
1418                        if let Some(known_variant) =
1419                            known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
1420                        {
1421                            known_cur.insert(
1422                                v.name.clone(),
1423                                KnownVariant {
1424                                    adt: known_variant.adt,
1425                                    variant: known_variant.variant,
1426                                },
1427                            );
1428                        } else {
1429                            known_cur.remove(&v.name);
1430                        }
1431                        typed_defs.push((v.name.clone(), typed_def));
1432                    }
1433
1434                    let (p_body, t_body, typed_body) =
1435                        infer_expr(unifier, supply, &env_cur, adts, &known_cur, cur)?;
1436
1437                    let mut typed = typed_body;
1438                    for (name, def) in typed_defs.into_iter().rev() {
1439                        typed = TypedExpr::new(
1440                            t_body.clone(),
1441                            TypedExprKind::Let {
1442                                name,
1443                                def: Box::new(def),
1444                                body: Box::new(typed),
1445                            },
1446                        );
1447                    }
1448                    Ok((p_body, t_body, typed))
1449                }
1450                Expr::LetRec(_, bindings, body) => {
1451                    let mut env_seed = env.clone();
1452                    let mut known_seed = known.clone();
1453                    let mut binding_tys = BTreeMap::new();
1454                    for (var, _ann, _def) in bindings {
1455                        let tv = Type::var(supply.fresh(Some(var.name.clone())));
1456                        env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
1457                        known_seed.remove(&var.name);
1458                        binding_tys.insert(var.name.clone(), tv);
1459                    }
1460
1461                    let mut inferred_defs = Vec::with_capacity(bindings.len());
1462                    for (var, ann, def) in bindings {
1463                        let (preds, def_ty, typed_def) =
1464                            infer_expr(unifier, supply, &env_seed, adts, &known_seed, def)?;
1465                        if let Some(ann) = ann {
1466                            let mut ann_vars = BTreeMap::new();
1467                            let ann_ty =
1468                                type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
1469                            unifier.unify(&def_ty, &ann_ty)?;
1470                        }
1471                        let binding_ty = binding_tys
1472                            .get(&var.name)
1473                            .cloned()
1474                            .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1475                        unifier.unify(&binding_ty, &def_ty)?;
1476                        let resolved_ty = unifier.apply_type(&binding_ty);
1477
1478                        if let Some(known_variant) =
1479                            known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
1480                        {
1481                            known_seed.insert(
1482                                var.name.clone(),
1483                                KnownVariant {
1484                                    adt: known_variant.adt,
1485                                    variant: known_variant.variant,
1486                                },
1487                            );
1488                        } else {
1489                            known_seed.remove(&var.name);
1490                        }
1491                        inferred_defs.push((var.name.clone(), preds, resolved_ty, typed_def));
1492                    }
1493
1494                    let mut env_body = env.clone();
1495                    let mut typed_bindings = Vec::with_capacity(inferred_defs.len());
1496                    for (name, preds, def_ty, typed_def) in inferred_defs {
1497                        let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
1498                        reject_ambiguous_scheme(&scheme)?;
1499                        env_body.extend(name.clone(), scheme);
1500                        typed_bindings.push((name, typed_def));
1501                    }
1502
1503                    let (p_body, t_body, typed_body) =
1504                        infer_expr(unifier, supply, &env_body, adts, &known_seed, body)?;
1505                    let typed = TypedExpr::new(
1506                        t_body.clone(),
1507                        TypedExprKind::LetRec {
1508                            bindings: typed_bindings,
1509                            body: Box::new(typed_body),
1510                        },
1511                    );
1512                    Ok((p_body, t_body, typed))
1513                }
1514                Expr::Ite(_, cond, then_expr, else_expr) => {
1515                    let (p1, t1, typed_cond) = infer_expr(unifier, supply, env, adts, known, cond)?;
1516                    unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
1517                    let (p2, t2, typed_then) =
1518                        infer_expr(unifier, supply, env, adts, known, then_expr)?;
1519                    let (p3, t3, typed_else) =
1520                        infer_expr(unifier, supply, env, adts, known, else_expr)?;
1521                    unifier.unify(&t2, &t3)?;
1522                    let out_ty = unifier.apply_type(&t2);
1523                    let mut preds = p1;
1524                    preds.extend(p2);
1525                    preds.extend(p3);
1526                    let typed = TypedExpr::new(
1527                        out_ty.clone(),
1528                        TypedExprKind::Ite {
1529                            cond: Box::new(typed_cond),
1530                            then_expr: Box::new(typed_then),
1531                            else_expr: Box::new(typed_else),
1532                        },
1533                    );
1534                    Ok((preds, out_ty, typed))
1535                }
1536                Expr::Tuple(_, elems) => {
1537                    let mut preds = Vec::new();
1538                    let mut types = Vec::new();
1539                    let mut typed_elems = Vec::new();
1540                    for elem in elems {
1541                        let (p1, t1, typed_elem) =
1542                            infer_expr(unifier, supply, env, adts, known, elem)?;
1543                        preds.extend(p1);
1544                        types.push(unifier.apply_type(&t1));
1545                        typed_elems.push(typed_elem);
1546                    }
1547                    let tuple_ty = Type::tuple(types);
1548                    let typed = TypedExpr::new(tuple_ty.clone(), TypedExprKind::Tuple(typed_elems));
1549                    Ok((preds, tuple_ty, typed))
1550                }
1551                Expr::List(_, elems) => {
1552                    let elem_tv = Type::var(supply.fresh(Some("a".into())));
1553                    let mut preds = Vec::new();
1554                    let mut typed_elems = Vec::new();
1555                    for elem in elems {
1556                        let (p1, t1, typed_elem) =
1557                            infer_expr(unifier, supply, env, adts, known, elem)?;
1558                        unifier.unify(&t1, &elem_tv)?;
1559                        preds.extend(p1);
1560                        typed_elems.push(typed_elem);
1561                    }
1562                    let list_ty = Type::app(
1563                        Type::builtin(BuiltinTypeId::List),
1564                        unifier.apply_type(&elem_tv),
1565                    );
1566                    let typed = TypedExpr::new(list_ty.clone(), TypedExprKind::List(typed_elems));
1567                    Ok((preds, list_ty, typed))
1568                }
1569                Expr::Dict(_, kvs) => {
1570                    let elem_tv = Type::var(supply.fresh(Some("v".into())));
1571                    let mut preds = Vec::new();
1572                    let mut typed_kvs = BTreeMap::new();
1573                    for (k, v) in kvs {
1574                        let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v)?;
1575                        unifier.unify(&t1, &elem_tv)?;
1576                        preds.extend(p1);
1577                        typed_kvs.insert(k.clone(), typed_v);
1578                    }
1579                    let dict_ty = Type::app(
1580                        Type::builtin(BuiltinTypeId::Dict),
1581                        unifier.apply_type(&elem_tv),
1582                    );
1583                    let typed = TypedExpr::new(dict_ty.clone(), TypedExprKind::Dict(typed_kvs));
1584                    Ok((preds, dict_ty, typed))
1585                }
1586                Expr::Match(_, scrutinee, arms) => {
1587                    let (p1, t1, typed_scrutinee) =
1588                        infer_expr(unifier, supply, env, adts, known, scrutinee)?;
1589                    let mut preds = p1;
1590                    let mut typed_arms = Vec::new();
1591                    let res_ty = Type::var(supply.fresh(Some("match".into())));
1592                    let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
1593
1594                    for (pat, expr) in arms {
1595                        let scrutinee_ty = unifier.apply_type(&t1);
1596                        let (p_pat, binds) =
1597                            infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
1598                        preds.extend(p_pat);
1599
1600                        let mut env_arm = env.clone();
1601                        for (name, ty) in binds {
1602                            env_arm
1603                                .extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
1604                        }
1605                        let mut known_arm = known.clone();
1606                        if let Expr::Var(var) = scrutinee.as_ref() {
1607                            match pat {
1608                                Pattern::Named(_, name, _) => {
1609                                    let name_sym = name.to_dotted_symbol();
1610                                    if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
1611                                        known_arm.insert(
1612                                            var.name.clone(),
1613                                            KnownVariant {
1614                                                adt: adt.name.clone(),
1615                                                variant: name_sym,
1616                                            },
1617                                        );
1618                                    } else {
1619                                        known_arm.remove(&var.name);
1620                                    }
1621                                }
1622                                _ => {
1623                                    known_arm.remove(&var.name);
1624                                }
1625                            }
1626                        }
1627                        let (p_expr, t_expr, typed_expr) =
1628                            infer_expr(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1629                        unifier.unify(&res_ty, &t_expr)?;
1630                        preds.extend(p_expr);
1631                        typed_arms.push((pat.clone(), typed_expr));
1632                    }
1633
1634                    let scrutinee_ty = unifier.apply_type(&t1);
1635                    check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1636                    let out_ty = unifier.apply_type(&res_ty);
1637                    let typed = TypedExpr::new(
1638                        out_ty.clone(),
1639                        TypedExprKind::Match {
1640                            scrutinee: Box::new(typed_scrutinee),
1641                            arms: typed_arms,
1642                        },
1643                    );
1644                    Ok((preds, out_ty, typed))
1645                }
1646                Expr::Ann(_, expr, ann) => {
1647                    let ann_ty = type_from_annotation_expr(adts, ann)?;
1648                    match expr.as_ref() {
1649                        Expr::RecordUpdate(_, base, updates) => {
1650                            infer_record_update_typed_with_hint(
1651                                unifier,
1652                                supply,
1653                                env,
1654                                adts,
1655                                known,
1656                                base.as_ref(),
1657                                updates,
1658                                &ann_ty,
1659                            )
1660                        }
1661                        _ => {
1662                            let (preds, expr_ty, typed_expr) =
1663                                infer_expr(unifier, supply, env, adts, known, expr)?;
1664                            unifier.unify(&expr_ty, &ann_ty)?;
1665                            let out_ty = unifier.apply_type(&ann_ty);
1666                            Ok((preds, out_ty, typed_expr))
1667                        }
1668                    }
1669                }
1670            }
1671        })()
1672    });
1673    res.map_err(|err| err.with_span(&span))
1674}
1675
1676fn ctor_lookup<'a>(
1677    adts: &'a BTreeMap<Symbol, AdtDecl>,
1678    name: &Symbol,
1679) -> Option<(&'a AdtDecl, &'a AdtVariant)> {
1680    let mut found = None;
1681    for adt in adts.values() {
1682        if let Some(variant) = adt.variants.iter().find(|v| &v.name == name) {
1683            if found.is_some() {
1684                return None;
1685            }
1686            found = Some((adt, variant));
1687        }
1688    }
1689    found
1690}
1691
1692fn record_fields(variant: &AdtVariant) -> Option<&[(Symbol, Type)]> {
1693    if variant.args.len() != 1 {
1694        return None;
1695    }
1696    match variant.args[0].as_ref() {
1697        TypeKind::Record(fields) => Some(fields),
1698        _ => None,
1699    }
1700}
1701
1702fn instantiate_variant_fields(
1703    adt: &AdtDecl,
1704    variant: &AdtVariant,
1705    supply: &mut TypeVarSupply,
1706) -> Option<(Type, Vec<(Symbol, Type)>)> {
1707    let fields = record_fields(variant)?;
1708    let mut subst = Subst::new_sync();
1709    for param in &adt.params {
1710        let fresh = Type::var(supply.fresh(param.var.name.clone()));
1711        subst = subst.insert(param.var.id, fresh);
1712    }
1713    let result_ty = adt.result_type().apply(&subst);
1714    let fields = fields
1715        .iter()
1716        .map(|(name, ty)| (name.clone(), ty.apply(&subst)))
1717        .collect();
1718    Some((result_ty, fields))
1719}
1720
1721fn known_variant_from_expr(
1722    expr: &Expr,
1723    expr_ty: &Type,
1724    adts: &BTreeMap<Symbol, AdtDecl>,
1725) -> Option<KnownVariant> {
1726    let mut expr = expr;
1727    while let Expr::Ann(_, inner, _) = expr {
1728        expr = inner.as_ref();
1729    }
1730    if matches!(expr_ty.as_ref(), TypeKind::Fun(..)) {
1731        return None;
1732    }
1733    let ctor = match expr {
1734        Expr::App(_, f, _) => match f.as_ref() {
1735            Expr::Var(var) => var.name.clone(),
1736            _ => return None,
1737        },
1738        _ => return None,
1739    };
1740    let (adt, variant) = ctor_lookup(adts, &ctor)?;
1741    record_fields(variant)?;
1742    Some(KnownVariant {
1743        adt: adt.name.clone(),
1744        variant: variant.name.clone(),
1745    })
1746}
1747
1748fn known_variant_from_expr_with_known(
1749    expr: &Expr,
1750    expr_ty: &Type,
1751    adts: &BTreeMap<Symbol, AdtDecl>,
1752    known: &KnownVariants,
1753) -> Option<KnownVariant> {
1754    let mut expr = expr;
1755    while let Expr::Ann(_, inner, _) = expr {
1756        expr = inner.as_ref();
1757    }
1758    match expr {
1759        Expr::Var(var) => known.get(&var.name).cloned(),
1760        Expr::RecordUpdate(_, base, _) => {
1761            known_variant_from_expr_with_known(base.as_ref(), expr_ty, adts, known)
1762        }
1763        _ => known_variant_from_expr(expr, expr_ty, adts),
1764    }
1765}
1766
1767fn select_record_variant<'a, F>(
1768    adts: &'a BTreeMap<Symbol, AdtDecl>,
1769    base_ty: &Type,
1770    known_variant: Option<KnownVariant>,
1771    field_for_errors: &Symbol,
1772    matches_fields: F,
1773) -> Result<(&'a AdtDecl, &'a AdtVariant), TypeError>
1774where
1775    F: Fn(&[(Symbol, Type)]) -> bool,
1776{
1777    if let Some(info) = known_variant {
1778        let adt = adts
1779            .get(&info.adt)
1780            .ok_or_else(|| TypeError::UnknownTypeName(info.adt.clone()))?;
1781        let variant = adt
1782            .variants
1783            .iter()
1784            .find(|v| v.name == info.variant)
1785            .ok_or_else(|| TypeError::UnknownField {
1786                field: field_for_errors.clone(),
1787                typ: base_ty.to_string(),
1788            })?;
1789        return Ok((adt, variant));
1790    }
1791
1792    if let Some(adt_name) = type_head_name(base_ty) {
1793        let adt = adts.get(adt_name).ok_or_else(|| TypeError::UnknownField {
1794            field: field_for_errors.clone(),
1795            typ: base_ty.to_string(),
1796        })?;
1797        if adt.variants.len() == 1 {
1798            return Ok((adt, &adt.variants[0]));
1799        }
1800        return Err(TypeError::FieldNotKnown {
1801            field: field_for_errors.clone(),
1802            typ: base_ty.to_string(),
1803        });
1804    }
1805
1806    if matches!(base_ty.as_ref(), TypeKind::Var(_)) {
1807        let mut candidates = Vec::new();
1808        for adt in adts.values() {
1809            if adt.variants.len() != 1 {
1810                continue;
1811            }
1812            let variant = &adt.variants[0];
1813            let Some(fields) = record_fields(variant) else {
1814                continue;
1815            };
1816            if matches_fields(fields) {
1817                candidates.push((adt, variant));
1818            }
1819        }
1820        if candidates.len() == 1 {
1821            return Ok(candidates.remove(0));
1822        }
1823        if candidates.is_empty() {
1824            return Err(TypeError::UnknownField {
1825                field: field_for_errors.clone(),
1826                typ: base_ty.to_string(),
1827            });
1828        }
1829        return Err(TypeError::FieldNotKnown {
1830            field: field_for_errors.clone(),
1831            typ: base_ty.to_string(),
1832        });
1833    }
1834
1835    Err(TypeError::UnknownField {
1836        field: field_for_errors.clone(),
1837        typ: base_ty.to_string(),
1838    })
1839}
1840
1841fn resolve_record_update(
1842    unifier: &mut Unifier<'_>,
1843    supply: &mut TypeVarSupply,
1844    adts: &BTreeMap<Symbol, AdtDecl>,
1845    base_ty: &Type,
1846    known_variant: Option<KnownVariant>,
1847    update_fields: &[Symbol],
1848) -> Result<(Type, Vec<(Symbol, Type)>), TypeError> {
1849    if let TypeKind::Record(fields) = base_ty.as_ref() {
1850        return Ok((base_ty.clone(), fields.clone()));
1851    }
1852
1853    let field_for_errors = update_fields.first().cloned().unwrap_or_else(|| sym("_"));
1854
1855    let (adt, variant) =
1856        select_record_variant(adts, base_ty, known_variant, &field_for_errors, |fields| {
1857            update_fields
1858                .iter()
1859                .all(|field| fields.iter().any(|(name, _)| name == field))
1860        })?;
1861
1862    let (result_ty, fields) =
1863        instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1864            TypeError::UnknownField {
1865                field: field_for_errors.clone(),
1866                typ: base_ty.to_string(),
1867            }
1868        })?;
1869
1870    for field in update_fields {
1871        if fields.iter().all(|(name, _)| name != field) {
1872            return Err(TypeError::UnknownField {
1873                field: field.clone(),
1874                typ: base_ty.to_string(),
1875            });
1876        }
1877    }
1878
1879    unifier.unify(base_ty, &result_ty)?;
1880    let result_ty = unifier.apply_type(&result_ty);
1881    let fields = fields
1882        .into_iter()
1883        .map(|(name, ty)| (name, unifier.apply_type(&ty)))
1884        .collect();
1885    Ok((result_ty, fields))
1886}
1887
1888fn resolve_projection(
1889    unifier: &mut Unifier<'_>,
1890    supply: &mut TypeVarSupply,
1891    adts: &BTreeMap<Symbol, AdtDecl>,
1892    base_ty: &Type,
1893    known_variant: Option<KnownVariant>,
1894    field: &Symbol,
1895) -> Result<Type, TypeError> {
1896    if let Ok(index) = field.as_ref().parse::<usize>() {
1897        let elem_ty = match base_ty.as_ref() {
1898            TypeKind::Tuple(elems) => {
1899                elems
1900                    .get(index)
1901                    .cloned()
1902                    .ok_or_else(|| TypeError::UnknownField {
1903                        field: field.clone(),
1904                        typ: base_ty.to_string(),
1905                    })?
1906            }
1907            TypeKind::Var(_) => {
1908                let mut elems = Vec::with_capacity(index + 1);
1909                for _ in 0..=index {
1910                    elems.push(Type::var(supply.fresh(Some(sym("t")))));
1911                }
1912                let tuple_ty = Type::tuple(elems.clone());
1913                unifier.unify(base_ty, &tuple_ty)?;
1914                elems[index].clone()
1915            }
1916            _ => {
1917                return Err(TypeError::UnknownField {
1918                    field: field.clone(),
1919                    typ: base_ty.to_string(),
1920                });
1921            }
1922        };
1923        return Ok(unifier.apply_type(&elem_ty));
1924    }
1925
1926    let (adt, variant) = select_record_variant(adts, base_ty, known_variant, field, |fields| {
1927        fields.iter().any(|(name, _)| name == field)
1928    })?;
1929
1930    let (result_ty, fields) =
1931        instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1932            TypeError::UnknownField {
1933                field: field.clone(),
1934                typ: base_ty.to_string(),
1935            }
1936        })?;
1937    let field_ty = fields
1938        .iter()
1939        .find(|(name, _)| name == field)
1940        .map(|(_, ty)| ty.clone())
1941        .ok_or_else(|| TypeError::UnknownField {
1942            field: field.clone(),
1943            typ: base_ty.to_string(),
1944        })?;
1945    unifier.unify(base_ty, &result_ty)?;
1946    Ok(unifier.apply_type(&field_ty))
1947}
1948
1949fn decompose_fun(typ: &Type, arity: usize) -> Option<(Vec<Type>, Type)> {
1950    let mut args = Vec::with_capacity(arity);
1951    let mut cur = typ.clone();
1952    for _ in 0..arity {
1953        match cur.as_ref() {
1954            TypeKind::Fun(a, b) => {
1955                args.push(a.clone());
1956                cur = b.clone();
1957            }
1958            _ => return None,
1959        }
1960    }
1961    Some((args, cur))
1962}
1963
/// Successful result of `infer_pattern`: the predicates collected while
/// checking the pattern, and the `(name, type)` bindings it introduces.
type InferPatternResult = (Vec<Predicate>, Vec<(Symbol, Type)>);
1965
/// Checks `pat` against `scrutinee_ty` and collects the variables it binds.
///
/// On success returns the predicates produced by instantiating constructor
/// schemes together with the `(name, type)` bindings introduced by the
/// pattern; binding types have the unifier's substitution applied. Unification
/// performed here is recorded in `unifier`, and fresh type variables are drawn
/// from `supply`.
///
/// # Errors
///
/// * `UnknownVar` when a constructor pattern names something not in `env`.
/// * `AmbiguousOverload` when the constructor name has more or fewer than one
///   scheme in `env`.
/// * `UnsupportedExpr("pattern constructor")` when the constructor type has
///   fewer arguments than the pattern has sub-patterns.
/// * `UnknownField` when a dict pattern names a key missing from a record
///   scrutinee.
/// * Any error from unification or from `charge_infer_node`.
///
/// Every error is tagged with the span of `pat`.
fn infer_pattern(
    unifier: &mut Unifier<'_>,
    supply: &mut TypeVarSupply,
    env: &TypeEnv,
    pat: &Pattern,
    scrutinee_ty: &Type,
) -> Result<InferPatternResult, TypeError> {
    let span = *pat.span();
    // The body runs inside a closure so that every `?` exit funnels through
    // the single `with_span` call at the bottom.
    let res = (|| {
        // Each pattern node visited is charged to the unifier.
        unifier.charge_infer_node()?;
        match pat {
            // Matches anything and binds nothing.
            Pattern::Wildcard(..) => Ok((vec![], vec![])),
            // Binds the variable to the (resolved) scrutinee type.
            Pattern::Var(var) => Ok((
                vec![],
                vec![(var.name.clone(), unifier.apply_type(scrutinee_ty))],
            )),
            Pattern::Named(_, name, ps) => {
                let ctor_name = name.to_dotted_symbol();
                let schemes = env
                    .lookup(&ctor_name)
                    .ok_or_else(|| TypeError::UnknownVar(ctor_name.clone()))?;
                // The constructor must resolve to exactly one scheme.
                if schemes.len() != 1 {
                    return Err(TypeError::AmbiguousOverload(ctor_name));
                }
                let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
                let (preds, ctor_ty) = instantiate(&scheme, supply);
                // Peel one argument type per sub-pattern off the constructor type.
                let (arg_tys, res_ty) = decompose_fun(&ctor_ty, ps.len())
                    .ok_or(TypeError::UnsupportedExpr("pattern constructor"))?;
                unifier.unify(&res_ty, scrutinee_ty)?;
                let mut all_preds = preds;
                let mut bindings = Vec::new();
                for (p, arg_ty) in ps.iter().zip(arg_tys.iter()) {
                    let arg_ty = unifier.apply_type(arg_ty);
                    let (p1, binds1) = infer_pattern(unifier, supply, env, p, &arg_ty)?;
                    all_preds.extend(p1);
                    bindings.extend(binds1);
                }
                // Re-apply the substitution: later sub-patterns may have
                // refined the types of earlier bindings.
                let bindings = bindings
                    .into_iter()
                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                    .collect();
                Ok((all_preds, bindings))
            }
            Pattern::List(_, ps) => {
                // The scrutinee must be `List a` for a fresh `a`.
                let elem_tv = Type::var(supply.fresh(Some("a".into())));
                let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
                unifier.unify(scrutinee_ty, &list_ty)?;
                let mut preds = Vec::new();
                let mut bindings = Vec::new();
                for p in ps {
                    let elem_ty = unifier.apply_type(&elem_tv);
                    let (p1, binds1) = infer_pattern(unifier, supply, env, p, &elem_ty)?;
                    preds.extend(p1);
                    bindings.extend(binds1);
                }
                let bindings = bindings
                    .into_iter()
                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                    .collect();
                Ok((preds, bindings))
            }
            Pattern::Cons(_, head, tail) => {
                let elem_tv = Type::var(supply.fresh(Some("a".into())));
                let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
                unifier.unify(scrutinee_ty, &list_ty)?;
                let mut preds = Vec::new();
                let mut bindings = Vec::new();

                // The head is checked against the element type ...
                let head_ty = unifier.apply_type(&elem_tv);
                let (p1, binds1) = infer_pattern(unifier, supply, env, head, &head_ty)?;
                preds.extend(p1);
                bindings.extend(binds1);

                // ... and the tail against a list of that element type.
                let tail_ty = Type::app(
                    Type::builtin(BuiltinTypeId::List),
                    unifier.apply_type(&elem_tv),
                );
                let (p2, binds2) = infer_pattern(unifier, supply, env, tail, &tail_ty)?;
                preds.extend(p2);
                bindings.extend(binds2);

                let bindings = bindings
                    .into_iter()
                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                    .collect();
                Ok((preds, bindings))
            }
            Pattern::Tuple(_, elems) => {
                // One fresh variable per element; unifying fixes the tuple arity.
                let mut elem_tys: Vec<Type> = (0..elems.len())
                    .map(|i| Type::var(supply.fresh(Some(format!("t{i}").into()))))
                    .collect();
                let expected = Type::tuple(elem_tys.clone());
                unifier.unify(scrutinee_ty, &expected)?;
                elem_tys = elem_tys
                    .into_iter()
                    .map(|t| unifier.apply_type(&t))
                    .collect();

                let mut preds = Vec::new();
                let mut bindings = Vec::new();
                for (p, ty) in elems.iter().zip(elem_tys.iter()) {
                    let (p_preds, p_binds) = infer_pattern(unifier, supply, env, p, ty)?;
                    preds.extend(p_preds);
                    bindings.extend(p_binds);
                }
                let bindings = bindings
                    .into_iter()
                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                    .collect();
                Ok((preds, bindings))
            }
            Pattern::Dict(_, fields) => {
                // NOTE(review): this inspects `scrutinee_ty` as passed, without
                // applying the unifier first; callers presumably hand in an
                // already-resolved type - confirm.
                if let TypeKind::Record(ty_fields) = scrutinee_ty.as_ref() {
                    // Record scrutinee: every key must be a declared field,
                    // and its sub-pattern is checked against that field's type.
                    let mut preds = Vec::new();
                    let mut bindings = Vec::new();
                    for (key, pat) in fields {
                        let ty = ty_fields
                            .iter()
                            .find(|(name, _)| name == key)
                            .map(|(_, ty)| unifier.apply_type(ty))
                            .ok_or_else(|| TypeError::UnknownField {
                                field: key.clone(),
                                typ: scrutinee_ty.to_string(),
                            })?;
                        let (p_preds, p_binds) = infer_pattern(unifier, supply, env, pat, &ty)?;
                        preds.extend(p_preds);
                        bindings.extend(p_binds);
                    }
                    let bindings = bindings
                        .into_iter()
                        .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                        .collect();
                    Ok((preds, bindings))
                } else {
                    // Otherwise the scrutinee must be `Dict v`; every
                    // sub-pattern is checked against the shared value type.
                    let elem_tv = Type::var(supply.fresh(Some("v".into())));
                    let dict_ty = Type::app(Type::builtin(BuiltinTypeId::Dict), elem_tv.clone());
                    unifier.unify(scrutinee_ty, &dict_ty)?;
                    let elem_ty = unifier.apply_type(&elem_tv);

                    let mut preds = Vec::new();
                    let mut bindings = Vec::new();
                    for (_key, pat) in fields {
                        let (p_preds, p_binds) =
                            infer_pattern(unifier, supply, env, pat, &elem_ty)?;
                        preds.extend(p_preds);
                        bindings.extend(p_binds);
                    }
                    let bindings = bindings
                        .into_iter()
                        .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                        .collect();
                    Ok((preds, bindings))
                }
            }
        }
    })();
    // Attach the pattern's span to whatever error escaped the closure.
    res.map_err(|err| err.with_span(&span))
}
2124
2125fn type_head_name(typ: &Type) -> Option<&Symbol> {
2126    let mut cur = typ;
2127    while let TypeKind::App(head, _) = cur.as_ref() {
2128        cur = head;
2129    }
2130    match cur.as_ref() {
2131        TypeKind::Con(tc) => Some(&tc.name),
2132        _ => None,
2133    }
2134}
2135
2136fn adt_name_from_patterns(
2137    adts: &BTreeMap<Symbol, AdtDecl>,
2138    patterns: &[Pattern],
2139) -> Option<Symbol> {
2140    let mut candidate: Option<Symbol> = None;
2141    for pat in patterns {
2142        let next = match pat {
2143            Pattern::Named(_, name, _) => {
2144                let name_sym = name.to_dotted_symbol();
2145                ctor_lookup(adts, &name_sym).map(|(adt, _)| adt.name.clone())
2146            }
2147            Pattern::List(..) | Pattern::Cons(..) => Some(sym("List")),
2148            _ => None,
2149        };
2150        if let Some(next) = next {
2151            match &candidate {
2152                None => candidate = Some(next),
2153                Some(prev) if *prev == next => {}
2154                Some(_) => return None,
2155            }
2156        }
2157    }
2158    candidate
2159}
2160
2161fn check_match_exhaustive(
2162    adts: &BTreeMap<Symbol, AdtDecl>,
2163    scrutinee_ty: &Type,
2164    patterns: &[Pattern],
2165) -> Result<(), TypeError> {
2166    if patterns
2167        .iter()
2168        .any(|p| matches!(p, Pattern::Wildcard(..) | Pattern::Var(_)))
2169    {
2170        return Ok(());
2171    }
2172    let adt_name = match type_head_name(scrutinee_ty).cloned() {
2173        Some(name) => name,
2174        None => match adt_name_from_patterns(adts, patterns) {
2175            Some(name) => name,
2176            None => return Ok(()),
2177        },
2178    };
2179    let adt = match adts.get(&adt_name) {
2180        Some(adt) => adt,
2181        None => return Ok(()),
2182    };
2183    let ctor_names: BTreeSet<Symbol> = adt.variants.iter().map(|v| v.name.clone()).collect();
2184    if ctor_names.is_empty() {
2185        return Ok(());
2186    }
2187    let mut covered = BTreeSet::new();
2188    for pat in patterns {
2189        match pat {
2190            Pattern::Named(_, name, _) => {
2191                let name_sym = name.to_dotted_symbol();
2192                if ctor_names.contains(&name_sym) {
2193                    covered.insert(name_sym);
2194                }
2195            }
2196            Pattern::List(_, elems) if adt_name.as_ref() == "List" && elems.is_empty() => {
2197                covered.insert(sym("Empty"));
2198            }
2199            Pattern::Cons(..) if adt_name.as_ref() == "List" => {
2200                covered.insert(sym("Cons"));
2201            }
2202            _ => {}
2203        }
2204    }
2205    let mut missing: Vec<Symbol> = ctor_names.difference(&covered).cloned().collect();
2206    if missing.is_empty() {
2207        return Ok(());
2208    }
2209    missing.sort();
2210    Err(TypeError::NonExhaustiveMatch {
2211        typ: scrutinee_ty.to_string(),
2212        missing,
2213    })
2214}
2215
2216#[cfg(test)]
2217mod tests {
2218    use super::*;
2219    use crate::{
2220        types::collect_adts_in_types,
2221        typesystem::{TypeSystemLimits, entails, generalize},
2222        unification::bind,
2223    };
2224    use rex_lexer::{Token, span::Span};
2225    use rex_parser::Parser;
2226    use rex_util::{GasCosts, GasMeter};
2227
2228    fn tvar(id: TypeVarId, name: &str) -> Type {
2229        Type::var(TypeVar::new(id, Some(sym(name))))
2230    }
2231
2232    fn dict_of(elem: Type) -> Type {
2233        Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
2234    }
2235
2236    #[test]
2237    fn unify_simple() {
2238        let t1 = Type::fun(tvar(0, "a"), Type::builtin(BuiltinTypeId::U32));
2239        let t2 = Type::fun(Type::builtin(BuiltinTypeId::U16), tvar(1, "b"));
2240        let subst = unify(&t1, &t2).unwrap();
2241        assert_eq!(subst.get(&0), Some(&Type::builtin(BuiltinTypeId::U16)));
2242        assert_eq!(subst.get(&1), Some(&Type::builtin(BuiltinTypeId::U32)));
2243    }
2244
2245    #[test]
2246    fn occurs_check_blocks_infinite_type() {
2247        let tv = TypeVar::new(0, Some(sym("a")));
2248        let t = Type::fun(Type::var(tv.clone()), Type::builtin(BuiltinTypeId::U8));
2249        let err = bind(&tv, &t).unwrap_err();
2250        assert!(matches!(err, TypeError::Occurs(_, _)));
2251    }
2252
2253    #[test]
2254    fn instantiate_and_generalize_round_trip() {
2255        let mut supply = TypeVarSupply::new();
2256        let a = Type::var(supply.fresh(Some(sym("a"))));
2257        let scheme = generalize(&TypeEnv::new(), vec![], Type::fun(a.clone(), a.clone()));
2258        let (preds, inst) = instantiate(&scheme, &mut supply);
2259        assert!(preds.is_empty());
2260        if let TypeKind::Fun(l, r) = inst.as_ref() {
2261            match (l.as_ref(), r.as_ref()) {
2262                (TypeKind::Var(_), TypeKind::Var(_)) => {}
2263                _ => panic!("expected polymorphic identity"),
2264            }
2265        } else {
2266            panic!("expected function type");
2267        }
2268    }
2269
2270    #[test]
2271    fn entail_superclasses() {
2272        let ts = TypeSystem::new_with_prelude().unwrap();
2273        let pred = Predicate::new("Semiring", Type::builtin(BuiltinTypeId::I32));
2274        let given = [Predicate::new(
2275            "AdditiveGroup",
2276            Type::builtin(BuiltinTypeId::I32),
2277        )];
2278        assert!(entails(&ts.classes, &given, &pred).unwrap());
2279    }
2280
2281    #[test]
2282    fn entail_instances() {
2283        let ts = TypeSystem::new_with_prelude().unwrap();
2284        let pred = Predicate::new("Field", Type::builtin(BuiltinTypeId::F32));
2285        assert!(entails(&ts.classes, &[], &pred).unwrap());
2286
2287        let pred_fail = Predicate::new("Field", Type::builtin(BuiltinTypeId::U32));
2288        assert!(!entails(&ts.classes, &[], &pred_fail).unwrap());
2289    }
2290
2291    #[test]
2292    fn prelude_injects_functions() {
2293        let ts = TypeSystem::new_with_prelude().unwrap();
2294        let minus = ts.env.lookup(&sym("-")).expect("minus in env");
2295        let div = ts.env.lookup(&sym("/")).expect("div in env");
2296        assert_eq!(minus.len(), 1);
2297        assert_eq!(div.len(), 1);
2298        let minus = &minus[0];
2299        let div = &div[0];
2300        assert_eq!(minus.preds.len(), 1);
2301        assert_eq!(minus.vars.len(), 1);
2302        assert_eq!(div.preds.len(), 1);
2303        assert_eq!(div.vars.len(), 1);
2304    }
2305
2306    #[test]
2307    fn adt_constructors_are_present() {
2308        let ts = TypeSystem::new_with_prelude().unwrap();
2309        assert!(ts.env.lookup(&sym("Empty")).is_some());
2310        assert!(ts.env.lookup(&sym("Cons")).is_some());
2311        assert!(ts.env.lookup(&sym("Ok")).is_some());
2312        assert!(ts.env.lookup(&sym("Err")).is_some());
2313        assert!(ts.env.lookup(&sym("Some")).is_some());
2314        assert!(ts.env.lookup(&sym("None")).is_some());
2315    }
2316
2317    fn parse_expr(code: &str) -> std::sync::Arc<rex_ast::expr::Expr> {
2318        let mut parser = Parser::new(Token::tokenize(code).unwrap());
2319        parser.parse_program(&mut GasMeter::default()).unwrap().expr
2320    }
2321
2322    fn parse_program(code: &str) -> rex_ast::expr::Program {
2323        let mut parser = Parser::new(Token::tokenize(code).unwrap());
2324        parser.parse_program(&mut GasMeter::default()).unwrap()
2325    }
2326
2327    #[test]
2328    fn infer_deep_list_does_not_overflow() {
2329        const N: usize = 40;
2330        let mut code = String::new();
2331        code.push_str("let xs = ");
2332        for _ in 0..N {
2333            code.push_str("Cons 0 (");
2334        }
2335        code.push_str("Empty");
2336        for _ in 0..N {
2337            code.push(')');
2338        }
2339        code.push_str(" in xs");
2340
2341        let parse_handle = std::thread::Builder::new()
2342            .name("infer_deep_list_parse".into())
2343            .stack_size(128 * 1024 * 1024)
2344            .spawn(move || {
2345                let tokens = Token::tokenize(&code).unwrap();
2346                let mut parser = Parser::new(tokens);
2347                parser.parse_program(&mut GasMeter::default())
2348            })
2349            .unwrap();
2350        let program = parse_handle.join().unwrap().unwrap();
2351        let expr = program.expr;
2352        let mut ts = TypeSystem::new_with_prelude().unwrap();
2353        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2354        assert_eq!(
2355            ty,
2356            Type::app(
2357                Type::builtin(BuiltinTypeId::List),
2358                Type::builtin(BuiltinTypeId::I32)
2359            )
2360        );
2361    }
2362
2363    #[test]
2364    fn collect_adts_in_types_finds_nested_unique_adts() {
2365        let foo = Type::user_con("Foo", 1);
2366        let bar = Type::user_con("Bar", 0);
2367        let ty = Type::fun(
2368            Type::app(
2369                Type::builtin(BuiltinTypeId::List),
2370                Type::app(foo.clone(), tvar(0, "a")),
2371            ),
2372            Type::tuple(vec![
2373                Type::app(foo.clone(), Type::builtin(BuiltinTypeId::I32)),
2374                bar.clone(),
2375            ]),
2376        );
2377
2378        let adts = collect_adts_in_types(vec![ty]).unwrap();
2379        assert_eq!(adts, vec![foo, bar]);
2380    }
2381
2382    #[test]
2383    fn collect_adts_in_types_rejects_conflicting_names() {
2384        let arity1 = Type::user_con("Thing", 1);
2385        let arity2 = Type::user_con("Thing", 2);
2386
2387        let err = collect_adts_in_types(vec![arity1.clone(), arity2.clone()]).unwrap_err();
2388        assert_eq!(err.conflicts.len(), 1);
2389        let conflict = &err.conflicts[0];
2390        assert_eq!(conflict.name, sym("Thing"));
2391        assert_eq!(conflict.definitions, vec![arity1, arity2]);
2392    }
2393
2394    #[test]
2395    fn infer_depth_limit_is_enforced() {
2396        const N: usize = 40;
2397        let mut code = String::new();
2398        code.push_str("let xs = ");
2399        for _ in 0..N {
2400            code.push_str("Cons 0 (");
2401        }
2402        code.push_str("Empty");
2403        for _ in 0..N {
2404            code.push(')');
2405        }
2406        code.push_str(" in xs");
2407
2408        let program = parse_program(&code);
2409        let mut ts = TypeSystem::new_with_prelude().unwrap();
2410        ts.set_limits(TypeSystemLimits {
2411            max_infer_depth: Some(8),
2412        });
2413
2414        let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
2415        assert!(
2416            err.to_string().contains("maximum inference depth exceeded"),
2417            "expected a max-depth inference error, got: {err:?}"
2418        );
2419    }
2420
2421    #[test]
2422    fn declare_fn_injects_scheme_for_use_sites() {
2423        let program = parse_program(
2424            r#"
2425            declare fn id x: a -> a
2426            id 1
2427            "#,
2428        );
2429        let mut ts = TypeSystem::new_with_prelude().unwrap();
2430        ts.register_decls(&program.decls).unwrap();
2431        let (preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2432        assert!(
2433            preds.is_empty()
2434                || preds.iter().all(|p| p.class.as_ref() == "Integral"
2435                    && p.typ == Type::builtin(BuiltinTypeId::I32))
2436        );
2437        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2438    }
2439
2440    #[test]
2441    fn declare_fn_is_noop_when_matching_existing_scheme() {
2442        let mut ts = TypeSystem::new_with_prelude().unwrap();
2443        ts.add_value(
2444            "foo",
2445            Scheme::new(
2446                vec![],
2447                vec![],
2448                Type::fun(
2449                    Type::builtin(BuiltinTypeId::I32),
2450                    Type::builtin(BuiltinTypeId::I32),
2451                ),
2452            ),
2453        );
2454
2455        let program = parse_program(
2456            r#"
2457            declare fn foo x: i32 -> i32
2458            0
2459            "#,
2460        );
2461        let rex_ast::expr::Decl::DeclareFn(fd) = &program.decls[0] else {
2462            panic!("expected declare fn decl");
2463        };
2464        ts.inject_declare_fn_decl(fd).unwrap();
2465    }
2466
2467    #[test]
2468    fn unit_type_parses_and_infers() {
2469        let program = parse_program(
2470            r#"
2471            fn unit_id x: () -> () = x
2472            unit_id ()
2473            "#,
2474        );
2475        let mut ts = TypeSystem::new_with_prelude().unwrap();
2476        ts.register_decls(&program.decls).unwrap();
2477        let (preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2478        assert!(preds.is_empty());
2479        assert_eq!(ty, Type::tuple(vec![]));
2480    }
2481
2482    fn strip_span(mut err: TypeError) -> TypeError {
2483        while let TypeError::Spanned { error, .. } = err {
2484            err = *error;
2485        }
2486        err
2487    }
2488
2489    #[test]
2490    fn type_errors_include_span() {
2491        let expr = parse_expr("missing");
2492        let mut ts = TypeSystem::new_with_prelude().unwrap();
2493        let err = infer(&mut ts, expr.as_ref()).unwrap_err();
2494        match err {
2495            TypeError::Spanned { span, error } => {
2496                assert_ne!(span, Span::default());
2497                assert!(matches!(
2498                    *error,
2499                    TypeError::UnknownVar(name) if name.as_ref() == "missing"
2500                ));
2501            }
2502            other => panic!("expected spanned error, got {other:?}"),
2503        }
2504    }
2505
2506    #[test]
2507    fn infer_with_gas_rejects_out_of_budget() {
2508        let expr = parse_expr("1");
2509        let mut ts = TypeSystem::new_with_prelude().unwrap();
2510        let mut gas = GasMeter::new(
2511            Some(0),
2512            GasCosts {
2513                infer_node: 1,
2514                unify_step: 0,
2515                ..GasCosts::sensible_defaults()
2516            },
2517        );
2518        let err = infer_with_gas(&mut ts, expr.as_ref(), &mut gas).unwrap_err();
2519        assert!(matches!(strip_span(err), TypeError::OutOfGas(..)));
2520    }
2521
2522    #[test]
2523    fn reject_user_redefinition_of_primitive_type_name() {
2524        let program = parse_program("type i32 = I32Wrap i32");
2525        let mut ts = TypeSystem::new_with_prelude().unwrap();
2526        let rex_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2527            panic!("expected type decl");
2528        };
2529        let err = ts.register_type_decl(decl).unwrap_err();
2530        assert!(matches!(
2531            err,
2532            TypeError::ReservedTypeName(name) if name.as_ref() == "i32"
2533        ));
2534    }
2535
2536    #[test]
2537    fn reject_user_redefinition_of_prelude_adt_name() {
2538        let program = parse_program("type Result e a = Nope e a");
2539        let mut ts = TypeSystem::new_with_prelude().unwrap();
2540        let rex_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2541            panic!("expected type decl");
2542        };
2543        let err = ts.register_type_decl(decl).unwrap_err();
2544        assert!(matches!(
2545            err,
2546            TypeError::ReservedTypeName(name) if name.as_ref() == "Result"
2547        ));
2548    }
2549
2550    #[test]
2551    fn reject_user_redefinition_of_promise_type_name() {
2552        let program = parse_program("type Promise a = PromiseWrap a");
2553        let mut ts = TypeSystem::new_with_prelude().unwrap();
2554        let rex_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2555            panic!("expected type decl");
2556        };
2557        let err = ts.register_type_decl(decl).unwrap_err();
2558        assert!(matches!(
2559            err,
2560            TypeError::ReservedTypeName(name) if name.as_ref() == "Promise"
2561        ));
2562    }
2563
2564    #[test]
2565    fn infer_polymorphic_id_tuple() {
2566        let expr = parse_expr(
2567            r#"
2568            let
2569                id = \x -> x
2570            in
2571                id (id 420, id 6.9, id "str")
2572            "#,
2573        );
2574        let mut ts = TypeSystem::new_with_prelude().unwrap();
2575        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2576        let expected = Type::tuple(vec![
2577            Type::builtin(BuiltinTypeId::I32),
2578            Type::builtin(BuiltinTypeId::F32),
2579            Type::builtin(BuiltinTypeId::String),
2580        ]);
2581        assert_eq!(ty, expected);
2582    }
2583
2584    #[test]
2585    fn infer_type_annotation_ok() {
2586        let expr = parse_expr("let x: i32 = 42 in x");
2587        let mut ts = TypeSystem::new_with_prelude().unwrap();
2588        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2589        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2590    }
2591
2592    #[test]
2593    fn infer_type_annotation_lambda_param() {
2594        let expr = parse_expr("\\ (a : f32) -> a");
2595        let mut ts = TypeSystem::new_with_prelude().unwrap();
2596        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2597        assert_eq!(
2598            ty,
2599            Type::fun(
2600                Type::builtin(BuiltinTypeId::F32),
2601                Type::builtin(BuiltinTypeId::F32)
2602            )
2603        );
2604    }
2605
2606    #[test]
2607    fn infer_type_annotation_is_alias() {
2608        let expr = parse_expr("\"hi\" is str");
2609        let mut ts = TypeSystem::new_with_prelude().unwrap();
2610        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2611        assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
2612    }
2613
2614    #[test]
2615    fn infer_type_annotation_with_promise_constructor() {
2616        let expr = parse_expr("\\(x: Promise i32) -> x");
2617        let mut ts = TypeSystem::new_with_prelude().unwrap();
2618        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2619        let promise_i32 = Type::promise(Type::builtin(BuiltinTypeId::I32));
2620        assert_eq!(ty, Type::fun(promise_i32.clone(), promise_i32));
2621    }
2622
2623    #[test]
2624    fn infer_type_annotation_mismatch_error() {
2625        let expr = parse_expr("let x: i32 = 3.14 in x");
2626        let mut ts = TypeSystem::new_with_prelude().unwrap();
2627        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
2628        assert!(matches!(err, TypeError::Unification(_, _)));
2629    }
2630
2631    #[test]
2632    fn infer_project_single_variant_let() {
2633        let program = parse_program(
2634            r#"
2635            type MyADT = MyVariant1 { field1: i32, field2: f32 }
2636            let
2637                x = MyVariant1 { field1 = 1, field2 = 2.0 }
2638            in
2639                (x.field1, x.field2)
2640            "#,
2641        );
2642        let mut ts = TypeSystem::new_with_prelude().unwrap();
2643        for decl in &program.decls {
2644            if let rex_ast::expr::Decl::Type(decl) = decl {
2645                ts.register_type_decl(decl).unwrap();
2646            }
2647        }
2648        let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2649        let expected = Type::tuple(vec![
2650            Type::builtin(BuiltinTypeId::I32),
2651            Type::builtin(BuiltinTypeId::F32),
2652        ]);
2653        assert_eq!(ty, expected);
2654    }
2655
2656    #[test]
2657    fn infer_project_known_variant_let() {
2658        let program = parse_program(
2659            r#"
2660            type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
2661            let
2662                x = MyVariant1 { field1 = 1, field2 = 2.0 }
2663            in
2664                x.field1
2665            "#,
2666        );
2667        let mut ts = TypeSystem::new_with_prelude().unwrap();
2668        for decl in &program.decls {
2669            if let rex_ast::expr::Decl::Type(decl) = decl {
2670                ts.register_type_decl(decl).unwrap();
2671            }
2672        }
2673        let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2674        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2675    }
2676
2677    #[test]
2678    fn infer_project_unknown_variant_error() {
2679        let program = parse_program(
2680            r#"
2681            type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
2682            let
2683                x = MyVariant2 1 2.0
2684            in
2685                x.field1
2686            "#,
2687        );
2688        let mut ts = TypeSystem::new_with_prelude().unwrap();
2689        for decl in &program.decls {
2690            if let rex_ast::expr::Decl::Type(decl) = decl {
2691                ts.register_type_decl(decl).unwrap();
2692            }
2693        }
2694        let err = strip_span(infer(&mut ts, program.expr.as_ref()).unwrap_err());
2695        assert!(matches!(err, TypeError::FieldNotKnown { .. }));
2696    }
2697
2698    #[test]
2699    fn infer_project_lambda_param_single_variant() {
2700        let program = parse_program(
2701            r#"
2702            type Boxed = Boxed { value: i32 }
2703            let
2704                f = \x -> x.value
2705            in
2706                f (Boxed { value = 1 })
2707            "#,
2708        );
2709        let mut ts = TypeSystem::new_with_prelude().unwrap();
2710        for decl in &program.decls {
2711            if let rex_ast::expr::Decl::Type(decl) = decl {
2712                ts.register_type_decl(decl).unwrap();
2713            }
2714        }
2715        let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2716        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2717    }
2718
2719    #[test]
2720    fn infer_project_in_match_arm() {
2721        let program = parse_program(
2722            r#"
2723            type MyADT = MyVariant1 { field1: i32 } | MyVariant2 i32
2724            let
2725                x = MyVariant1 { field1 = 1 }
2726            in
2727                match x
2728                    when MyVariant1 { field1 } -> x.field1
2729                    when MyVariant2 _ -> 0
2730            "#,
2731        );
2732        let mut ts = TypeSystem::new_with_prelude().unwrap();
2733        for decl in &program.decls {
2734            if let rex_ast::expr::Decl::Type(decl) = decl {
2735                ts.register_type_decl(decl).unwrap();
2736            }
2737        }
2738        let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2739        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2740    }
2741
2742    #[test]
2743    fn infer_nested_let_lambda_match_option() {
2744        let expr = parse_expr(
2745            r#"
2746            let
2747                choose = \flag a b -> if flag then a else b,
2748                build = \flag ->
2749                    let
2750                        pick = choose flag,
2751                        val = pick 1 2
2752                    in
2753                        Some val
2754            in
2755                match (build true)
2756                    when Some x -> x
2757                    when None -> 0
2758            "#,
2759        );
2760        let mut ts = TypeSystem::new_with_prelude().unwrap();
2761        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2762        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2763    }
2764
2765    #[test]
2766    fn infer_polymorphic_apply_in_tuple() {
2767        let expr = parse_expr(
2768            r#"
2769            let
2770                apply = \f x -> f x,
2771                id = \x -> x,
2772                wrap = \x -> (x, x)
2773            in
2774                (apply id 1, apply id "hi", apply wrap true)
2775            "#,
2776        );
2777        let mut ts = TypeSystem::new_with_prelude().unwrap();
2778        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2779        let expected = Type::tuple(vec![
2780            Type::builtin(BuiltinTypeId::I32),
2781            Type::builtin(BuiltinTypeId::String),
2782            Type::tuple(vec![
2783                Type::builtin(BuiltinTypeId::Bool),
2784                Type::builtin(BuiltinTypeId::Bool),
2785            ]),
2786        ]);
2787        assert_eq!(ty, expected);
2788    }
2789
2790    #[test]
2791    fn infer_nested_result_option_match() {
2792        let expr = parse_expr(
2793            r#"
2794            let
2795                unwrap = \x ->
2796                    match x
2797                        when Ok (Some v) -> v
2798                        when Ok None -> 0
2799                        when Err _ -> 0
2800            in
2801                unwrap (Ok (Some 5))
2802            "#,
2803        );
2804        let mut ts = TypeSystem::new_with_prelude().unwrap();
2805        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2806        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2807    }
2808
2809    #[test]
2810    fn infer_head_or_list_match() {
2811        let expr = parse_expr(
2812            r#"
2813            let
2814                head_or = \fallback xs ->
2815                    match xs
2816                        when [] -> fallback
2817                        when x::xs -> x
2818            in
2819                (head_or 0 [1, 2, 3], head_or 0 [])
2820            "#,
2821        );
2822        let mut ts = TypeSystem::new_with_prelude().unwrap();
2823        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2824        let expected = Type::tuple(vec![
2825            Type::builtin(BuiltinTypeId::I32),
2826            Type::builtin(BuiltinTypeId::I32),
2827        ]);
2828        assert_eq!(ty, expected);
2829    }
2830
2831    #[test]
2832    fn infer_head_or_list_match_cons_constructor_form() {
2833        let expr = parse_expr(
2834            r#"
2835            let
2836                head_or = \fallback xs ->
2837                    match xs
2838                        when [] -> fallback
2839                        when Cons x xs1 -> x
2840            in
2841                (head_or 0 (Cons 1 (Cons 2 Empty)), head_or 0 Empty)
2842            "#,
2843        );
2844        let mut ts = TypeSystem::new_with_prelude().unwrap();
2845        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2846        let expected = Type::tuple(vec![
2847            Type::builtin(BuiltinTypeId::I32),
2848            Type::builtin(BuiltinTypeId::I32),
2849        ]);
2850        assert_eq!(ty, expected);
2851    }
2852
2853    #[test]
2854    fn infer_record_pattern_in_lambda() {
2855        let program = parse_program(
2856            r#"
2857            type Pair = Pair { left: i32, right: i32 }
2858            let
2859                sum = \p ->
2860                    match p
2861                        when Pair { left, right } -> left + right
2862            in
2863                sum (Pair { left = 1, right = 2 })
2864            "#,
2865        );
2866        let mut ts = TypeSystem::new_with_prelude().unwrap();
2867        for decl in &program.decls {
2868            if let rex_ast::expr::Decl::Type(decl) = decl {
2869                ts.register_type_decl(decl).unwrap();
2870            }
2871        }
2872        let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2873        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2874    }
2875
2876    #[test]
2877    fn infer_fn_decl_simple() {
2878        let program = parse_program(
2879            r#"
2880            fn add (x: i32, y: i32) -> i32 = x + y
2881            add 1 2
2882            "#,
2883        );
2884        let mut ts = TypeSystem::new_with_prelude().unwrap();
2885        let expr = program.expr_with_fns();
2886        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2887        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2888    }
2889
2890    #[test]
2891    fn infer_fn_decl_signature_form() {
2892        let program = parse_program(
2893            r#"
2894            fn add : i32 -> i32 -> i32 = \x y -> x + y
2895            add 1 2
2896            "#,
2897        );
2898        let mut ts = TypeSystem::new_with_prelude().unwrap();
2899        let expr = program.expr_with_fns();
2900        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2901        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2902    }
2903
2904    #[test]
2905    fn infer_fn_decl_polymorphic_where_constraints() {
2906        let program = parse_program(
2907            r#"
2908            fn my_add (x: a, y: a) -> a where AdditiveMonoid a = x + y
2909            (my_add 1 2, my_add 1.0 2.0)
2910            "#,
2911        );
2912        let mut ts = TypeSystem::new_with_prelude().unwrap();
2913        let expr = program.expr_with_fns();
2914        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2915        assert_eq!(
2916            ty,
2917            Type::tuple(vec![
2918                Type::builtin(BuiltinTypeId::I32),
2919                Type::builtin(BuiltinTypeId::F32)
2920            ])
2921        );
2922    }
2923
2924    #[test]
2925    fn infer_additive_monoid_constraint() {
2926        let expr = parse_expr("\\x y -> x + y");
2927        let mut ts = TypeSystem::new_with_prelude().unwrap();
2928        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2929        assert_eq!(preds.len(), 1);
2930        assert_eq!(preds[0].class.as_ref(), "AdditiveMonoid");
2931
2932        if let TypeKind::Fun(a, rest) = ty.as_ref()
2933            && let TypeKind::Fun(b, c) = rest.as_ref()
2934        {
2935            assert_eq!(a.as_ref(), b.as_ref());
2936            assert_eq!(b.as_ref(), c.as_ref());
2937            assert_eq!(preds[0].typ, a.clone());
2938            return;
2939        }
2940        panic!("expected a -> a -> a");
2941    }
2942
2943    #[test]
2944    fn infer_multiplicative_monoid_constraint() {
2945        let expr = parse_expr("\\x y -> x * y");
2946        let mut ts = TypeSystem::new_with_prelude().unwrap();
2947        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2948        assert_eq!(preds.len(), 1);
2949        assert_eq!(preds[0].class.as_ref(), "MultiplicativeMonoid");
2950
2951        if let TypeKind::Fun(a, rest) = ty.as_ref()
2952            && let TypeKind::Fun(b, c) = rest.as_ref()
2953        {
2954            assert_eq!(a.as_ref(), b.as_ref());
2955            assert_eq!(b.as_ref(), c.as_ref());
2956            assert_eq!(preds[0].typ, a.clone());
2957            return;
2958        }
2959        panic!("expected a -> a -> a");
2960    }
2961
2962    #[test]
2963    fn infer_additive_group_constraint() {
2964        let expr = parse_expr("\\x y -> x - y");
2965        let mut ts = TypeSystem::new_with_prelude().unwrap();
2966        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2967        assert_eq!(preds.len(), 1);
2968        assert_eq!(preds[0].class.as_ref(), "AdditiveGroup");
2969
2970        if let TypeKind::Fun(a, rest) = ty.as_ref()
2971            && let TypeKind::Fun(b, c) = rest.as_ref()
2972        {
2973            assert_eq!(a.as_ref(), b.as_ref());
2974            assert_eq!(b.as_ref(), c.as_ref());
2975            assert_eq!(preds[0].typ, a.clone());
2976            return;
2977        }
2978        panic!("expected a -> a -> a");
2979    }
2980
2981    #[test]
2982    fn infer_integral_constraint() {
2983        let expr = parse_expr("\\x y -> x % y");
2984        let mut ts = TypeSystem::new_with_prelude().unwrap();
2985        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2986        assert_eq!(preds.len(), 1);
2987        assert_eq!(preds[0].class.as_ref(), "Integral");
2988
2989        if let TypeKind::Fun(a, rest) = ty.as_ref()
2990            && let TypeKind::Fun(b, c) = rest.as_ref()
2991        {
2992            assert_eq!(a.as_ref(), b.as_ref());
2993            assert_eq!(b.as_ref(), c.as_ref());
2994            assert_eq!(preds[0].typ, a.clone());
2995            return;
2996        }
2997        panic!("expected a -> a -> a");
2998    }
2999
3000    #[test]
3001    fn infer_literal_addition_defaults() {
3002        let expr = parse_expr("1 + 2");
3003        let mut ts = TypeSystem::new_with_prelude().unwrap();
3004        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3005        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3006        assert_eq!(preds.len(), 2);
3007        assert!(preds.iter().any(|p| p.class.as_ref() == "AdditiveMonoid"));
3008        assert!(preds.iter().any(|p| p.class.as_ref() == "Integral"));
3009        assert!(
3010            preds
3011                .iter()
3012                .all(|p| p.typ == Type::builtin(BuiltinTypeId::I32))
3013        );
3014    }
3015
3016    #[test]
3017    fn infer_mod_defaults() {
3018        let expr = parse_expr("1 % 2");
3019        let mut ts = TypeSystem::new_with_prelude().unwrap();
3020        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3021        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3022        assert_eq!(preds.len(), 1);
3023        assert_eq!(preds[0].class.as_ref(), "Integral");
3024        assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::I32));
3025    }
3026
3027    #[test]
3028    fn infer_get_list_type() {
3029        let expr = parse_expr("get 1 [1, 2, 3]");
3030        let mut ts = TypeSystem::new_with_prelude().unwrap();
3031        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3032        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3033        assert!(preds.iter().any(|p| p.class.as_ref() == "Indexable"));
3034        assert!(preds.iter().all(|p| {
3035            p.class.as_ref() == "Indexable"
3036                || (p.class.as_ref() == "Integral" && p.typ == Type::builtin(BuiltinTypeId::I32))
3037        }));
3038        for pred in preds.iter().filter(|p| p.class.as_ref() == "Indexable") {
3039            assert!(entails(&ts.classes, &[], pred).unwrap());
3040        }
3041    }
3042
3043    #[test]
3044    fn infer_get_tuple_type() {
3045        let expr = parse_expr("(1, 'Hello', true).0");
3046        let mut ts = TypeSystem::new_with_prelude().unwrap();
3047        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3048        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3049        assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3050
3051        let expr = parse_expr("(1, 'Hello', true).1");
3052        let mut ts = TypeSystem::new_with_prelude().unwrap();
3053        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3054        assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
3055        assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3056
3057        let expr = parse_expr("(1, 'Hello', true).2");
3058        let mut ts = TypeSystem::new_with_prelude().unwrap();
3059        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3060        assert_eq!(ty, Type::builtin(BuiltinTypeId::Bool));
3061        assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3062    }
3063
3064    #[test]
3065    fn infer_division_defaults() {
3066        let expr = parse_expr("1.0 / 2.0");
3067        let mut ts = TypeSystem::new_with_prelude().unwrap();
3068        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3069        assert_eq!(ty, Type::builtin(BuiltinTypeId::F32));
3070        assert_eq!(preds.len(), 1);
3071        assert_eq!(preds[0].class.as_ref(), "Field");
3072        assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::F32));
3073        assert!(entails(&ts.classes, &[], &preds[0]).unwrap());
3074    }
3075
3076    #[test]
3077    fn infer_unbound_variable_error() {
3078        let expr = parse_expr("missing");
3079        let mut ts = TypeSystem::new_with_prelude().unwrap();
3080        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3081        assert!(matches!(
3082            err,
3083            TypeError::UnknownVar(name) if name.as_ref() == "missing"
3084        ));
3085    }
3086
3087    #[test]
3088    fn infer_if_branch_type_mismatch_error() {
3089        let expr = parse_expr(r#"if true then 1 else "no""#);
3090        let mut ts = TypeSystem::new_with_prelude().unwrap();
3091        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3092        match err {
3093            TypeError::Unification(a, b) => {
3094                let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
3095                assert!(ok, "expected i32 vs string, got {a} vs {b}");
3096            }
3097            other => panic!("expected unification error, got {other:?}"),
3098        }
3099    }
3100
3101    #[test]
3102    fn infer_unknown_pattern_constructor_error() {
3103        let expr = parse_expr("match 1 when Nope -> 1");
3104        let mut ts = TypeSystem::new_with_prelude().unwrap();
3105        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3106        assert!(matches!(
3107            err,
3108            TypeError::UnknownVar(name) if name.as_ref() == "Nope"
3109        ));
3110    }
3111
3112    #[test]
3113    fn infer_ambiguous_overload_error() {
3114        let mut ts = TypeSystem::new();
3115        let a = TypeVar::new(0, Some(sym("a")));
3116        let b = TypeVar::new(1, Some(sym("b")));
3117        let scheme_a = Scheme::new(vec![a.clone()], vec![], Type::var(a));
3118        let scheme_b = Scheme::new(vec![b.clone()], vec![], Type::var(b));
3119        ts.add_overload(sym("dup"), scheme_a);
3120        ts.add_overload(sym("dup"), scheme_b);
3121        let expr = parse_expr("dup");
3122        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3123        assert!(matches!(
3124            err,
3125            TypeError::AmbiguousOverload(name) if name.as_ref() == "dup"
3126        ));
3127    }
3128
3129    #[test]
3130    fn infer_if_cond_not_bool_error() {
3131        let expr = parse_expr("if 1 then 2 else 3");
3132        let mut ts = TypeSystem::new_with_prelude().unwrap();
3133        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3134        match err {
3135            TypeError::Unification(a, b) => {
3136                let ok = (a == "bool" && b == "i32") || (a == "i32" && b == "bool");
3137                assert!(ok, "expected bool vs i32, got {a} vs {b}");
3138            }
3139            other => panic!("expected unification error, got {other:?}"),
3140        }
3141    }
3142
3143    #[test]
3144    fn infer_apply_non_function_error() {
3145        let expr = parse_expr("1 2");
3146        let mut ts = TypeSystem::new_with_prelude().unwrap();
3147        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3148        assert!(matches!(err, TypeError::Unification(_, _)));
3149    }
3150
3151    #[test]
3152    fn infer_list_element_mismatch_error() {
3153        let expr = parse_expr("[1, true]");
3154        let mut ts = TypeSystem::new_with_prelude().unwrap();
3155        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3156        match err {
3157            TypeError::Unification(a, b) => {
3158                let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
3159                assert!(ok, "expected i32 vs bool, got {a} vs {b}");
3160            }
3161            other => panic!("expected unification error, got {other:?}"),
3162        }
3163    }
3164
3165    #[test]
3166    fn infer_dict_value_mismatch_error() {
3167        let expr = parse_expr("{a = 1, b = true}");
3168        let mut ts = TypeSystem::new_with_prelude().unwrap();
3169        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3170        match err {
3171            TypeError::Unification(a, b) => {
3172                let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
3173                assert!(ok, "expected i32 vs bool, got {a} vs {b}");
3174            }
3175            other => panic!("expected unification error, got {other:?}"),
3176        }
3177    }
3178
3179    #[test]
3180    fn infer_match_list_on_non_list_error() {
3181        let expr = parse_expr("match 1 when [x] -> x");
3182        let mut ts = TypeSystem::new_with_prelude().unwrap();
3183        assert!(infer(&mut ts, expr.as_ref()).is_err());
3184    }
3185
3186    #[test]
3187    fn infer_pattern_constructor_arity_error() {
3188        let expr = parse_expr("match (Ok 1) when Ok x y -> x");
3189        let mut ts = TypeSystem::new_with_prelude().unwrap();
3190        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3191        assert!(matches!(
3192            err,
3193            TypeError::UnsupportedExpr("pattern constructor")
3194        ));
3195    }
3196
3197    #[test]
3198    fn infer_match_arm_type_mismatch_error() {
3199        let expr = parse_expr(r#"match 1 when _ -> 1 when _ -> "no""#);
3200        let mut ts = TypeSystem::new_with_prelude().unwrap();
3201        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3202        match err {
3203            TypeError::Unification(a, b) => {
3204                let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
3205                assert!(ok, "expected i32 vs string, got {a} vs {b}");
3206            }
3207            other => panic!("expected unification error, got {other:?}"),
3208        }
3209    }
3210
3211    #[test]
3212    fn infer_match_option_on_non_option_error() {
3213        let expr = parse_expr("match 1 when Some x -> x");
3214        let mut ts = TypeSystem::new_with_prelude().unwrap();
3215        assert!(infer(&mut ts, expr.as_ref()).is_err());
3216    }
3217
3218    #[test]
3219    fn infer_dict_pattern_on_non_dict_error() {
3220        let expr = parse_expr("match 1 when {a} -> a");
3221        let mut ts = TypeSystem::new_with_prelude().unwrap();
3222        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3223        assert!(matches!(err, TypeError::Unification(_, _)));
3224    }
3225
3226    #[test]
3227    fn infer_cons_pattern_on_non_list_error() {
3228        let expr = parse_expr("match 1 when x::xs -> x");
3229        let mut ts = TypeSystem::new_with_prelude().unwrap();
3230        assert!(infer(&mut ts, expr.as_ref()).is_err());
3231    }
3232
3233    #[test]
3234    fn infer_apply_wrong_arg_type_error() {
3235        let expr = parse_expr("(\\x -> x + 1) true");
3236        let mut ts = TypeSystem::new_with_prelude().unwrap();
3237        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3238        assert!(matches!(err, TypeError::Unification(_, _)));
3239    }
3240
3241    #[test]
3242    fn infer_self_application_occurs_error() {
3243        let expr = parse_expr("\\x -> x x");
3244        let mut ts = TypeSystem::new_with_prelude().unwrap();
3245        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3246        assert!(matches!(err, TypeError::Occurs(_, _)));
3247    }
3248
3249    #[test]
3250    fn infer_apply_constructor_too_many_args_error() {
3251        let expr = parse_expr("Some 1 2");
3252        let mut ts = TypeSystem::new_with_prelude().unwrap();
3253        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3254        assert!(matches!(err, TypeError::Unification(_, _)));
3255    }
3256
3257    #[test]
3258    fn infer_operator_type_mismatch_error() {
3259        let expr = parse_expr("1 + true");
3260        let mut ts = TypeSystem::new_with_prelude().unwrap();
3261        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3262        assert!(matches!(err, TypeError::Unification(_, _)));
3263    }
3264
3265    #[test]
3266    fn infer_non_exhaustive_match_is_error() {
3267        let expr = parse_expr("match (Ok 1) when Ok x -> x");
3268        let mut ts = TypeSystem::new_with_prelude().unwrap();
3269        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3270        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3271    }
3272
3273    #[test]
3274    fn infer_non_exhaustive_match_on_bound_var_error() {
3275        let expr = parse_expr("let x = Ok 1 in match x when Ok y -> y");
3276        let mut ts = TypeSystem::new_with_prelude().unwrap();
3277        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3278        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3279    }
3280
3281    #[test]
3282    fn infer_non_exhaustive_match_in_lambda_error() {
3283        let expr = parse_expr("\\x -> match x when Ok y -> y");
3284        let mut ts = TypeSystem::new_with_prelude().unwrap();
3285        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3286        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3287    }
3288
3289    #[test]
3290    fn infer_non_exhaustive_option_match_error() {
3291        let expr = parse_expr("match (Some 1) when Some x -> x");
3292        let mut ts = TypeSystem::new_with_prelude().unwrap();
3293        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3294        match err {
3295            TypeError::NonExhaustiveMatch { missing, .. } => {
3296                assert_eq!(missing, vec![sym("None")]);
3297            }
3298            other => panic!("expected non-exhaustive match, got {other:?}"),
3299        }
3300    }
3301
3302    #[test]
3303    fn infer_non_exhaustive_result_match_error() {
3304        let expr = parse_expr("match (Err 1) when Ok x -> x");
3305        let mut ts = TypeSystem::new_with_prelude().unwrap();
3306        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3307        match err {
3308            TypeError::NonExhaustiveMatch { missing, .. } => {
3309                assert_eq!(missing, vec![sym("Err")]);
3310            }
3311            other => panic!("expected non-exhaustive match, got {other:?}"),
3312        }
3313    }
3314
3315    #[test]
3316    fn infer_non_exhaustive_list_missing_empty_error() {
3317        let expr = parse_expr("match [1, 2] when x::xs -> x");
3318        let mut ts = TypeSystem::new_with_prelude().unwrap();
3319        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3320        match err {
3321            TypeError::NonExhaustiveMatch { missing, .. } => {
3322                assert_eq!(missing, vec![sym("Empty")]);
3323            }
3324            other => panic!("expected non-exhaustive match, got {other:?}"),
3325        }
3326    }
3327
3328    #[test]
3329    fn infer_non_exhaustive_list_match_on_bound_var_error() {
3330        let expr = parse_expr("let xs = [1, 2] in match xs when x::xs -> x");
3331        let mut ts = TypeSystem::new_with_prelude().unwrap();
3332        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3333        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3334    }
3335
3336    #[test]
3337    fn infer_non_exhaustive_list_missing_cons_error() {
3338        let expr = parse_expr("match [1] when [] -> 0");
3339        let mut ts = TypeSystem::new_with_prelude().unwrap();
3340        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3341        match err {
3342            TypeError::NonExhaustiveMatch { missing, .. } => {
3343                assert_eq!(missing, vec![sym("Cons")]);
3344            }
3345            other => panic!("expected non-exhaustive match, got {other:?}"),
3346        }
3347    }
3348
3349    #[test]
3350    fn infer_match_list_patterns_on_result_error() {
3351        let expr = parse_expr("match (Ok 1) when [] -> 0 when x::xs -> 1");
3352        let mut ts = TypeSystem::new_with_prelude().unwrap();
3353        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3354        assert!(matches!(err, TypeError::Unification(_, _)));
3355    }
3356
3357    #[test]
3358    fn infer_missing_instances_produce_unsatisfied_predicates() {
3359        for (name, code) in [
3360            ("division", "1 / 2"),
3361            ("eq_dict", "{a = 1} == {a = 2}"),
3362            ("min_bool", "min [true]"),
3363            ("map_dict", r#"map (\x -> x) {a = 1}"#),
3364        ] {
3365            let (class, pred_type, expected_ty) = match name {
3366                "division" => (
3367                    "Field",
3368                    Type::builtin(BuiltinTypeId::I32),
3369                    Some(Type::builtin(BuiltinTypeId::I32)),
3370                ),
3371                "eq_dict" => ("Eq", dict_of(Type::builtin(BuiltinTypeId::I32)), None),
3372                "min_bool" => ("Ord", Type::builtin(BuiltinTypeId::Bool), None),
3373                "map_dict" => ("Functor", Type::builtin(BuiltinTypeId::Dict), None),
3374                _ => unreachable!("unknown test case {name}"),
3375            };
3376
3377            let expr = parse_expr(code);
3378            let mut ts = TypeSystem::new_with_prelude().unwrap();
3379            let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3380            if let Some(expected) = expected_ty {
3381                assert_eq!(ty, expected, "{name}");
3382            }
3383
3384            let pred = preds
3385                .iter()
3386                .find(|p| p.class.as_ref() == class && p.typ == pred_type)
3387                .unwrap();
3388            assert!(!entails(&ts.classes, &[], pred).unwrap(), "{name}");
3389        }
3390    }
3391
3392    #[test]
3393    fn record_update_single_variant_adt_infers() {
3394        let program = parse_program(
3395            r#"
3396            type Foo = Bar { x: i32, y: i32 }
3397            let
3398              foo: Foo = Bar { x = 1, y = 2 },
3399              bar = { foo with { x = 3 } }
3400            in
3401              bar
3402            "#,
3403        );
3404        let mut ts = TypeSystem::new_with_prelude().unwrap();
3405        ts.register_decls(&program.decls).unwrap();
3406        let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3407        assert_eq!(typ.to_string(), "Foo");
3408    }
3409
3410    #[test]
3411    fn record_update_unknown_field_errors() {
3412        let program = parse_program(
3413            r#"
3414            type Foo = Bar { x: i32 }
3415            let
3416              foo: Foo = Bar { x = 1 }
3417            in
3418              { foo with { y = 2 } }
3419            "#,
3420        );
3421        let mut ts = TypeSystem::new_with_prelude().unwrap();
3422        ts.register_decls(&program.decls).unwrap();
3423        let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
3424        let err = strip_span(err);
3425        assert!(matches!(err, TypeError::UnknownField { .. }));
3426    }
3427
3428    #[test]
3429    fn record_update_requires_refined_variant_for_sum_types() {
3430        let program = parse_program(
3431            r#"
3432            type Foo = Bar { x: i32 } | Baz { x: i32 }
3433            let
3434              f = \ (foo : Foo) -> { foo with { x = 2 } }
3435            in
3436              f (Bar { x = 1 })
3437            "#,
3438        );
3439        let mut ts = TypeSystem::new_with_prelude().unwrap();
3440        ts.register_decls(&program.decls).unwrap();
3441        let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
3442        let err = strip_span(err);
3443        assert!(matches!(err, TypeError::FieldNotKnown { .. }));
3444    }
3445
3446    #[test]
3447    fn record_update_allowed_after_match_refines_variant() {
3448        let program = parse_program(
3449            r#"
3450            type Foo = Bar { x: i32 } | Baz { x: i32 }
3451            let
3452              f = \ (foo : Foo) ->
3453                match foo
3454                  when Bar {x} -> { foo with { x = x + 1 } }
3455                  when Baz {x} -> { foo with { x = x + 2 } }
3456            in
3457              f (Bar { x = 1 })
3458            "#,
3459        );
3460        let mut ts = TypeSystem::new_with_prelude().unwrap();
3461        ts.register_decls(&program.decls).unwrap();
3462        let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3463        assert_eq!(typ.to_string(), "Foo");
3464    }
3465
3466    #[test]
3467    fn record_update_plain_record_type() {
3468        let program = parse_program(
3469            r#"
3470            let
3471              f = \ (r : { x: i32, y: i32 }) -> { r with { y = 9 } }
3472            in
3473              f { x = 1, y = 2 }
3474            "#,
3475        );
3476        let mut ts = TypeSystem::new_with_prelude().unwrap();
3477        ts.register_decls(&program.decls).unwrap();
3478        let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3479        assert_eq!(typ.to_string(), "{x: i32, y: i32}");
3480    }
3481
3482    #[test]
3483    fn infer_typed_hole_expr_is_hole_kind() {
3484        let expr = parse_expr("?");
3485        let mut ts = TypeSystem::new_with_prelude().unwrap();
3486        let (typed, _preds, _ty) = infer_typed(&mut ts, expr.as_ref()).unwrap();
3487        assert!(
3488            matches!(typed.kind, TypedExprKind::Hole),
3489            "typed={typed:#?}"
3490        );
3491    }
3492
3493    #[test]
3494    fn infer_hole_with_annotation_unifies_to_annotation() {
3495        let expr = parse_expr("let x : i32 = ? in x");
3496        let mut ts = TypeSystem::new_with_prelude().unwrap();
3497        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3498        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3499    }
3500
3501    #[test]
3502    fn infer_hole_in_if_condition_is_bool_constrained() {
3503        let expr = parse_expr("if ? then 1 else 2");
3504        let mut ts = TypeSystem::new_with_prelude().unwrap();
3505        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3506        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3507    }
3508
3509    #[test]
3510    fn infer_hole_in_arithmetic_is_numeric_constrained() {
3511        let expr = parse_expr("? + 1");
3512        let mut ts = TypeSystem::new_with_prelude().unwrap();
3513        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3514        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3515    }
3516
3517    #[test]
3518    fn infer_hole_arithmetic_conflicting_annotation_failure() {
3519        let expr = parse_expr("let x : string = (? + 1) in x");
3520        let mut ts = TypeSystem::new_with_prelude().unwrap();
3521        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3522        assert!(matches!(err, TypeError::Unification(_, _)), "err={err:#?}");
3523    }
3524}