Skip to main content

rexlang_typesystem/
inference.rs

1use super::*;
2
3fn dedup_preds(preds: Vec<Predicate>) -> Vec<Predicate> {
4    let mut seen = HashSet::new();
5    let mut out = Vec::with_capacity(preds.len());
6    for pred in preds {
7        if seen.insert(pred.clone()) {
8            out.push(pred);
9        }
10    }
11    out
12}
13
14fn is_integral_primitive(typ: &Type) -> bool {
15    matches!(
16        typ.as_ref(),
17        TypeKind::Con(TypeConst {
18            builtin_id: Some(
19                BuiltinTypeId::U8
20                    | BuiltinTypeId::U16
21                    | BuiltinTypeId::U32
22                    | BuiltinTypeId::U64
23                    | BuiltinTypeId::I8
24                    | BuiltinTypeId::I16
25                    | BuiltinTypeId::I32
26                    | BuiltinTypeId::I64
27            ),
28            ..
29        })
30    )
31}
32
33fn finalize_infer_for_public_api(
34    mut preds: Vec<Predicate>,
35    mut typ: Type,
36) -> Result<(Vec<Predicate>, Type), TypeError> {
37    let mut subst = Subst::new_sync();
38    for pred in &preds {
39        if pred.class.as_ref() == "Integral"
40            && let TypeKind::Var(tv) = pred.typ.as_ref()
41        {
42            subst = subst.insert(tv.id, Type::builtin(BuiltinTypeId::I32));
43        }
44    }
45
46    if !subst_is_empty(&subst) {
47        preds = dedup_preds(preds.apply(&subst));
48        typ = typ.apply(&subst);
49    }
50
51    for pred in &preds {
52        if pred.class.as_ref() != "Integral" {
53            continue;
54        }
55        if matches!(pred.typ.as_ref(), TypeKind::Var(_)) || is_integral_primitive(&pred.typ) {
56            continue;
57        }
58        return Err(TypeError::Unification("i32".into(), pred.typ.to_string()));
59    }
60
61    Ok((preds, typ))
62}
63
64#[derive(Clone, Debug)]
65struct KnownVariant {
66    adt: Symbol,
67    variant: Symbol,
68}
69
70type KnownVariants = HashMap<Symbol, KnownVariant>;
71
72fn apply_scheme_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> Scheme {
73    let preds = scheme
74        .preds
75        .iter()
76        .map(|pred| Predicate::new(pred.class.clone(), unifier.apply_type(&pred.typ)))
77        .collect();
78    let typ = unifier.apply_type(&scheme.typ);
79    Scheme::new(scheme.vars.clone(), preds, typ)
80}
81
82fn scheme_ftv_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> HashSet<TypeVarId> {
83    let mut ftv = unifier.apply_type(&scheme.typ).ftv();
84    for pred in &scheme.preds {
85        ftv.extend(unifier.apply_type(&pred.typ).ftv());
86    }
87    for var in &scheme.vars {
88        ftv.remove(&var.id);
89    }
90    ftv
91}
92
93fn env_ftv_with_unifier(env: &TypeEnv, unifier: &mut Unifier<'_>) -> HashSet<TypeVarId> {
94    let mut out = HashSet::new();
95    for (_name, schemes) in env.values.iter() {
96        for scheme in schemes {
97            out.extend(scheme_ftv_with_unifier(scheme, unifier));
98        }
99    }
100    out
101}
102
103fn generalize_with_unifier(
104    env: &TypeEnv,
105    preds: Vec<Predicate>,
106    typ: Type,
107    unifier: &mut Unifier<'_>,
108) -> Scheme {
109    let preds: Vec<Predicate> = preds
110        .into_iter()
111        .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
112        .collect();
113    let typ = unifier.apply_type(&typ);
114    let mut vars: Vec<TypeVar> = typ
115        .ftv()
116        .union(&preds.ftv())
117        .copied()
118        .collect::<HashSet<_>>()
119        .difference(&env_ftv_with_unifier(env, unifier))
120        .cloned()
121        .map(|id| TypeVar::new(id, None))
122        .collect();
123    vars.sort_by_key(|v| v.id);
124    Scheme::new(vars, preds, typ)
125}
126
127fn monomorphic_scheme_with_unifier(
128    preds: Vec<Predicate>,
129    typ: Type,
130    unifier: &mut Unifier<'_>,
131) -> Scheme {
132    let preds = dedup_preds(
133        preds
134            .into_iter()
135            .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
136            .collect(),
137    );
138    let typ = unifier.apply_type(&typ);
139    Scheme::new(vec![], preds, typ)
140}
141
142pub fn infer_typed(
143    type_system: &mut TypeSystem,
144    expr: &Expr,
145) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
146    infer_typed_inner(type_system, expr)
147}
148
149pub fn infer_typed_with_gas(
150    type_system: &mut TypeSystem,
151    expr: &Expr,
152    gas: &mut GasMeter,
153) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
154    let known = KnownVariants::new();
155    let mut unifier = Unifier::with_gas(gas, type_system.limits.max_infer_depth);
156    let (preds, t, typed) = infer_expr(
157        &mut unifier,
158        &mut type_system.supply,
159        &type_system.env,
160        &type_system.adts,
161        &known,
162        expr,
163    )
164    .map_err(|err| with_span(expr.span(), err))?;
165    let subst = unifier.into_subst();
166    let mut typed = typed.apply(&subst);
167    let mut preds = dedup_preds(preds.apply(&subst));
168    let mut t = t.apply(&subst);
169    let improve = improve_indexable(&preds)?;
170    if !subst_is_empty(&improve) {
171        typed = typed.apply(&improve);
172        preds = dedup_preds(preds.apply(&improve));
173        t = t.apply(&improve);
174    }
175    type_system.check_predicate_kinds(&preds)?;
176    Ok((typed, preds, t))
177}
178
179fn infer_typed_inner(
180    type_system: &mut TypeSystem,
181    expr: &Expr,
182) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
183    let known = KnownVariants::new();
184    let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
185    let (preds, t, typed) = infer_expr(
186        &mut unifier,
187        &mut type_system.supply,
188        &type_system.env,
189        &type_system.adts,
190        &known,
191        expr,
192    )
193    .map_err(|err| with_span(expr.span(), err))?;
194    let subst = unifier.into_subst();
195    let mut typed = typed.apply(&subst);
196    let mut preds = dedup_preds(preds.apply(&subst));
197    let mut t = t.apply(&subst);
198    let improve = improve_indexable(&preds)?;
199    if !subst_is_empty(&improve) {
200        typed = typed.apply(&improve);
201        preds = dedup_preds(preds.apply(&improve));
202        t = t.apply(&improve);
203    }
204    type_system.check_predicate_kinds(&preds)?;
205    Ok((typed, preds, t))
206}
207
208pub fn infer(
209    type_system: &mut TypeSystem,
210    expr: &Expr,
211) -> Result<(Vec<Predicate>, Type), TypeError> {
212    infer_inner(type_system, expr)
213}
214
215pub fn infer_with_gas(
216    type_system: &mut TypeSystem,
217    expr: &Expr,
218    gas: &mut GasMeter,
219) -> Result<(Vec<Predicate>, Type), TypeError> {
220    let known = KnownVariants::new();
221    let mut unifier = Unifier::with_gas(gas, type_system.limits.max_infer_depth);
222    let (preds, t) = infer_expr_type(
223        &mut unifier,
224        &mut type_system.supply,
225        &type_system.env,
226        &type_system.adts,
227        &known,
228        expr,
229    )
230    .map_err(|err| with_span(expr.span(), err))?;
231    let subst = unifier.into_subst();
232    let preds = dedup_preds(preds.apply(&subst));
233    let t = t.apply(&subst);
234    type_system.check_predicate_kinds(&preds)?;
235    finalize_infer_for_public_api(preds, t)
236}
237
238fn infer_inner(
239    type_system: &mut TypeSystem,
240    expr: &Expr,
241) -> Result<(Vec<Predicate>, Type), TypeError> {
242    let known = KnownVariants::new();
243    let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
244    let (preds, t) = infer_expr_type(
245        &mut unifier,
246        &mut type_system.supply,
247        &type_system.env,
248        &type_system.adts,
249        &known,
250        expr,
251    )
252    .map_err(|err| with_span(expr.span(), err))?;
253    let subst = unifier.into_subst();
254    let mut preds = dedup_preds(preds.apply(&subst));
255    let mut t = t.apply(&subst);
256    let improve = improve_indexable(&preds)?;
257    if !subst_is_empty(&improve) {
258        preds = dedup_preds(preds.apply(&improve));
259        t = t.apply(&improve);
260    }
261    type_system.check_predicate_kinds(&preds)?;
262    finalize_infer_for_public_api(preds, t)
263}
264
265fn improve_indexable(preds: &[Predicate]) -> Result<Subst, TypeError> {
266    let mut subst = Subst::new_sync();
267    loop {
268        let mut changed = false;
269        for pred in preds {
270            let pred = pred.apply(&subst);
271            if pred.class.as_ref() != "Indexable" {
272                continue;
273            }
274            let TypeKind::Tuple(parts) = pred.typ.as_ref() else {
275                continue;
276            };
277            if parts.len() != 2 {
278                continue;
279            }
280            let container = parts[0].clone();
281            let elem = parts[1].clone();
282            let s = indexable_elem_subst(&container, &elem)?;
283            if !subst_is_empty(&s) {
284                subst = compose_subst(s, subst);
285                changed = true;
286            }
287        }
288        if !changed {
289            break;
290        }
291    }
292    Ok(subst)
293}
294
295fn indexable_elem_subst(container: &Type, elem: &Type) -> Result<Subst, TypeError> {
296    match container.as_ref() {
297        TypeKind::App(head, arg) => match head.as_ref() {
298            TypeKind::Con(tc)
299                if matches!(
300                    tc.builtin_id,
301                    Some(BuiltinTypeId::List | BuiltinTypeId::Array)
302                ) =>
303            {
304                unify(elem, arg)
305            }
306            _ => Ok(Subst::new_sync()),
307        },
308        TypeKind::Tuple(elems) => {
309            if elems.is_empty() {
310                return Ok(Subst::new_sync());
311            }
312            let mut subst = Subst::new_sync();
313            let mut cur = elems[0].clone();
314            for ty in elems.iter().skip(1) {
315                let s_next = unify(&cur.apply(&subst), &ty.apply(&subst))?;
316                subst = compose_subst(s_next, subst);
317                cur = cur.apply(&subst);
318            }
319            let elem = elem.apply(&subst);
320            let s_elem = unify(&elem, &cur.apply(&subst))?;
321            Ok(compose_subst(s_elem, subst))
322        }
323        _ => Ok(Subst::new_sync()),
324    }
325}
326
327type LambdaChain<'a> = (
328    Vec<(Symbol, Option<TypeExpr>)>,
329    Vec<TypeConstraint>,
330    &'a Expr,
331);
332
333fn collect_lambda_chain<'a>(expr: &'a Expr) -> LambdaChain<'a> {
334    let mut params = Vec::new();
335    let mut constraints = Vec::new();
336    let mut cur = expr;
337    let mut seen_constraints = false;
338    while let Expr::Lam(_, _scope, param, ann, lam_constraints, body) = cur {
339        if !lam_constraints.is_empty() {
340            if seen_constraints {
341                break;
342            }
343            constraints = lam_constraints.clone();
344            seen_constraints = true;
345        }
346        params.push((param.name.clone(), ann.clone()));
347        cur = body.as_ref();
348    }
349    (params, constraints, cur)
350}
351
352fn collect_app_chain(expr: &Expr) -> (&Expr, Vec<&Expr>) {
353    let mut args = Vec::new();
354    let mut cur = expr;
355    while let Expr::App(_, f, x) = cur {
356        args.push(x.as_ref());
357        cur = f.as_ref();
358    }
359    args.reverse();
360    (cur, args)
361}
362
363fn narrow_overload_candidates(candidates: &[Type], arg_ty: &Type) -> Vec<Type> {
364    let mut out = Vec::new();
365    for candidate in candidates {
366        let Some((params, ret)) = decompose_fun(candidate, 1) else {
367            continue;
368        };
369        let param = &params[0];
370        if let Ok(s) = unify(param, arg_ty) {
371            out.push(ret.apply(&s));
372        }
373    }
374    out
375}
376
377fn unary_app_arg(typ: &Type, ctor_name: &str) -> Option<Type> {
378    let TypeKind::App(head, arg) = typ.as_ref() else {
379        return None;
380    };
381    let TypeKind::Con(tc) = head.as_ref() else {
382        return None;
383    };
384    (tc.name.as_ref() == ctor_name && tc.arity == 1).then(|| arg.clone())
385}
386
387fn infer_app_arg_type(
388    unifier: &mut Unifier<'_>,
389    supply: &mut TypeVarSupply,
390    env: &TypeEnv,
391    adts: &HashMap<Symbol, AdtDecl>,
392    known: &KnownVariants,
393    arg_hint: Option<Type>,
394    arg: &Expr,
395) -> Result<(Vec<Predicate>, Type), TypeError> {
396    match (arg_hint, arg) {
397        (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
398            infer_record_update_type_with_hint(
399                unifier,
400                supply,
401                env,
402                adts,
403                known,
404                base.as_ref(),
405                updates,
406                &arg_hint,
407            )
408        }
409        (Some(arg_hint), Expr::Dict(_, kvs))
410            if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
411        {
412            let TypeKind::Record(fields) = arg_hint.as_ref() else {
413                unreachable!("guarded by matches!")
414            };
415            let expected: HashMap<_, _> =
416                fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
417            let mut seen = HashSet::new();
418            let mut preds = Vec::new();
419            for (k, v) in kvs {
420                let expected_ty = expected
421                    .get(k)
422                    .ok_or_else(|| TypeError::UnknownField {
423                        field: k.clone(),
424                        typ: Type::record(fields.clone()).to_string(),
425                    })?
426                    .clone();
427                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
428                unifier.unify(&t1, &expected_ty)?;
429                preds.extend(p1);
430                seen.insert(k.clone());
431            }
432            for key in expected.keys() {
433                if !seen.contains(key.as_ref()) {
434                    return Err(TypeError::UnknownField {
435                        field: key.clone(),
436                        typ: Type::record(fields.clone()).to_string(),
437                    });
438                }
439            }
440            let record_ty = Type::record(
441                fields
442                    .iter()
443                    .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
444                    .collect(),
445            );
446            Ok((preds, record_ty))
447        }
448        _ => infer_expr_type(unifier, supply, env, adts, known, arg),
449    }
450}
451
452fn infer_app_arg_typed(
453    unifier: &mut Unifier<'_>,
454    supply: &mut TypeVarSupply,
455    env: &TypeEnv,
456    adts: &HashMap<Symbol, AdtDecl>,
457    known: &KnownVariants,
458    arg_hint: Option<Type>,
459    arg: &Expr,
460) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
461    match (arg_hint, arg) {
462        (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
463            infer_record_update_typed_with_hint(
464                unifier,
465                supply,
466                env,
467                adts,
468                known,
469                base.as_ref(),
470                updates,
471                &arg_hint,
472            )
473        }
474        (Some(arg_hint), Expr::Dict(_, kvs))
475            if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
476        {
477            let TypeKind::Record(fields) = arg_hint.as_ref() else {
478                unreachable!("guarded by matches!")
479            };
480            let mut preds = Vec::new();
481            let mut typed_kvs = BTreeMap::new();
482            let expected: HashMap<_, _> =
483                fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
484            for (k, v) in kvs {
485                let expected_ty = expected
486                    .get(k)
487                    .ok_or_else(|| TypeError::UnknownField {
488                        field: k.clone(),
489                        typ: Type::record(fields.clone()).to_string(),
490                    })?
491                    .clone();
492                let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
493                unifier.unify(&t1, &expected_ty)?;
494                preds.extend(p1);
495                typed_kvs.insert(k.clone(), typed_v);
496            }
497            for key in expected.keys() {
498                if !typed_kvs.contains_key(key.as_ref()) {
499                    return Err(TypeError::UnknownField {
500                        field: key.clone(),
501                        typ: Type::record(fields.clone()).to_string(),
502                    });
503                }
504            }
505            let record_ty = Type::record(
506                fields
507                    .iter()
508                    .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
509                    .collect(),
510            );
511            let typed = TypedExpr::new(record_ty.clone(), TypedExprKind::Dict(typed_kvs));
512            Ok((preds, record_ty, typed))
513        }
514        _ => infer_expr(unifier, supply, env, adts, known, arg),
515    }
516}
517
518#[allow(clippy::too_many_arguments)]
519fn infer_record_update_type_with_hint(
520    unifier: &mut Unifier<'_>,
521    supply: &mut TypeVarSupply,
522    env: &TypeEnv,
523    adts: &HashMap<Symbol, AdtDecl>,
524    known: &KnownVariants,
525    base: &Expr,
526    updates: &BTreeMap<Symbol, Arc<Expr>>,
527    hint_ty: &Type,
528) -> Result<(Vec<Predicate>, Type), TypeError> {
529    let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
530    unifier.unify(&t_base, hint_ty)?;
531    let base_ty = unifier.apply_type(&t_base);
532    let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
533    let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
534    let (result_ty, fields) = resolve_record_update(
535        unifier,
536        supply,
537        adts,
538        &base_ty,
539        known_variant,
540        &update_fields,
541    )?;
542    let expected: HashMap<_, _> = fields.into_iter().collect();
543
544    let mut preds = p_base;
545    for (k, v) in updates {
546        let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
547            field: k.clone(),
548            typ: result_ty.to_string(),
549        })?;
550        let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
551        unifier.unify(&t1, expected_ty)?;
552        preds.extend(p1);
553    }
554    Ok((preds, result_ty))
555}
556
557#[allow(clippy::too_many_arguments)]
558fn infer_record_update_typed_with_hint(
559    unifier: &mut Unifier<'_>,
560    supply: &mut TypeVarSupply,
561    env: &TypeEnv,
562    adts: &HashMap<Symbol, AdtDecl>,
563    known: &KnownVariants,
564    base: &Expr,
565    updates: &BTreeMap<Symbol, Arc<Expr>>,
566    hint_ty: &Type,
567) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
568    let (p_base, t_base, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
569    unifier.unify(&t_base, hint_ty)?;
570    let base_ty = unifier.apply_type(&t_base);
571    let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
572    let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
573    let (result_ty, fields) = resolve_record_update(
574        unifier,
575        supply,
576        adts,
577        &base_ty,
578        known_variant,
579        &update_fields,
580    )?;
581    let expected: HashMap<_, _> = fields.into_iter().collect();
582
583    let mut preds = p_base;
584    let mut typed_updates = BTreeMap::new();
585    for (k, v) in updates {
586        let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
587            field: k.clone(),
588            typ: result_ty.to_string(),
589        })?;
590        let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
591        unifier.unify(&t1, expected_ty)?;
592        preds.extend(p1);
593        typed_updates.insert(k.clone(), typed_v);
594    }
595
596    let typed = TypedExpr::new(
597        result_ty.clone(),
598        TypedExprKind::RecordUpdate {
599            base: Box::new(typed_base),
600            updates: typed_updates,
601        },
602    );
603    Ok((preds, result_ty, typed))
604}
605
606fn infer_expr_type(
607    unifier: &mut Unifier<'_>,
608    supply: &mut TypeVarSupply,
609    env: &TypeEnv,
610    adts: &HashMap<Symbol, AdtDecl>,
611    known: &KnownVariants,
612    expr: &Expr,
613) -> Result<(Vec<Predicate>, Type), TypeError> {
614    let span = *expr.span();
615    let res = unifier.with_infer_depth(span, |unifier| {
616        infer_expr_type_inner(unifier, supply, env, adts, known, expr)
617    });
618    res.map_err(|err| with_span(&span, err))
619}
620
621fn infer_expr_type_inner(
622    unifier: &mut Unifier<'_>,
623    supply: &mut TypeVarSupply,
624    env: &TypeEnv,
625    adts: &HashMap<Symbol, AdtDecl>,
626    known: &KnownVariants,
627    expr: &Expr,
628) -> Result<(Vec<Predicate>, Type), TypeError> {
629    unifier.charge_infer_node()?;
630    match expr {
631        Expr::Bool(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Bool))),
632        Expr::Uint(_, _) => {
633            let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
634            Ok((vec![Predicate::new("Integral", lit_ty.clone())], lit_ty))
635        }
636        Expr::Int(_, _) => {
637            let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
638            Ok((
639                vec![
640                    Predicate::new("Integral", lit_ty.clone()),
641                    Predicate::new("AdditiveGroup", lit_ty.clone()),
642                ],
643                lit_ty,
644            ))
645        }
646        Expr::Float(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::F32))),
647        Expr::String(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::String))),
648        Expr::Uuid(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Uuid))),
649        Expr::DateTime(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::DateTime))),
650        Expr::Hole(_) => {
651            let t = Type::var(supply.fresh(Some(sym("hole"))));
652            Ok((vec![], t))
653        }
654        Expr::Var(var) => {
655            let schemes = env
656                .lookup(&var.name)
657                .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
658            if schemes.len() == 1 {
659                let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
660                let (preds, t) = instantiate(&scheme, supply);
661                Ok((preds, t))
662            } else {
663                for scheme in schemes {
664                    if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
665                        return Err(TypeError::AmbiguousOverload(var.name.clone()));
666                    }
667                }
668                let t = Type::var(supply.fresh(Some(var.name.clone())));
669                Ok((vec![], t))
670            }
671        }
672        Expr::Lam(..) => {
673            let (params, constraints, body) = collect_lambda_chain(expr);
674            let mut ann_vars = HashMap::new();
675            let mut param_tys = Vec::with_capacity(params.len());
676            for (name, ann) in &params {
677                let param_ty = match ann {
678                    Some(ann) => type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?,
679                    None => Type::var(supply.fresh(Some(name.clone()))),
680                };
681                param_tys.push((name.clone(), param_ty));
682            }
683
684            let mut env1 = env.clone();
685            let mut known_body = known.clone();
686            for (name, param_ty) in &param_tys {
687                env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
688                known_body.remove(name);
689            }
690
691            let (mut preds, body_ty) =
692                infer_expr_type(unifier, supply, &env1, adts, &known_body, body)?;
693            let constraint_preds =
694                predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
695            preds.extend(constraint_preds);
696
697            let mut fun_ty = unifier.apply_type(&body_ty);
698            for (_, param_ty) in param_tys.iter().rev() {
699                fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
700            }
701            Ok((preds, fun_ty))
702        }
703        Expr::App(..) => {
704            let (head, args) = collect_app_chain(expr);
705            let (mut preds, mut func_ty) =
706                infer_expr_type(unifier, supply, env, adts, known, head)?;
707            let mut overload_name = None;
708            let mut overload_candidates = if let Expr::Var(var) = head {
709                if let Some(schemes) = env.lookup(&var.name) {
710                    if schemes.len() <= 1 {
711                        None
712                    } else {
713                        let mut candidates = Vec::new();
714                        for scheme in schemes {
715                            if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
716                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
717                            }
718                            let scheme = apply_scheme_with_unifier(scheme, unifier);
719                            let (p, typ) = instantiate(&scheme, supply);
720                            if !p.is_empty() {
721                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
722                            }
723                            candidates.push(typ);
724                        }
725                        overload_name = Some(var.name.clone());
726                        Some(candidates)
727                    }
728                } else {
729                    None
730                }
731            } else {
732                None
733            };
734            for arg in args {
735                let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
736                    TypeKind::Fun(arg, _) => Some(arg.clone()),
737                    _ => None,
738                };
739                let (p_arg, arg_ty) =
740                    infer_app_arg_type(unifier, supply, env, adts, known, arg_hint, arg)?;
741                let arg_ty = unifier.apply_type(&arg_ty);
742                if let Some(candidates) = overload_candidates.take() {
743                    let candidates = candidates
744                        .into_iter()
745                        .map(|t| unifier.apply_type(&t))
746                        .collect::<Vec<_>>();
747                    let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
748                    if narrowed.is_empty()
749                        && let Some(name) = &overload_name
750                    {
751                        return Err(TypeError::AmbiguousOverload(name.clone()));
752                    }
753                    overload_candidates = Some(narrowed);
754                }
755                let res_ty = match overload_candidates.as_ref() {
756                    Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
757                    _ => Type::var(supply.fresh(Some("r".into()))),
758                };
759                unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
760                preds.extend(p_arg);
761                func_ty = match overload_candidates.as_ref() {
762                    Some(candidates) if candidates.len() == 1 => unifier.apply_type(&candidates[0]),
763                    _ => unifier.apply_type(&res_ty),
764                };
765            }
766            Ok((preds, func_ty))
767        }
768        Expr::Project(_, base, field) => {
769            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, base)?;
770            let base_ty = unifier.apply_type(&t1);
771            let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
772            let field_ty =
773                resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
774            Ok((p1, field_ty))
775        }
776        Expr::RecordUpdate(_, base, updates) => {
777            let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
778            let base_ty = unifier.apply_type(&t_base);
779            let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
780            let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
781            let (result_ty, fields) = resolve_record_update(
782                unifier,
783                supply,
784                adts,
785                &base_ty,
786                known_variant,
787                &update_fields,
788            )?;
789            let expected: HashMap<_, _> = fields.into_iter().collect();
790
791            let mut preds = p_base;
792            for (k, v) in updates {
793                let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
794                    field: k.clone(),
795                    typ: result_ty.to_string(),
796                })?;
797                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
798                unifier.unify(&t1, expected_ty)?;
799                preds.extend(p1);
800            }
801            Ok((preds, result_ty))
802        }
803        Expr::Let(..) => {
804            let mut bindings = Vec::new();
805            let mut cur = expr;
806            while let Expr::Let(_, v, ann, d, b) = cur {
807                bindings.push((v.clone(), ann.clone(), d.clone()));
808                cur = b.as_ref();
809            }
810
811            let mut env_cur = env.clone();
812            let mut known_cur = known.clone();
813            for (v, ann, d) in bindings {
814                let (p1, t1) = if let Some(ref ann_expr) = ann {
815                    let mut ann_vars = HashMap::new();
816                    let ann_ty =
817                        type_from_annotation_expr_vars(adts, ann_expr, &mut ann_vars, supply)?;
818                    match d.as_ref() {
819                        Expr::RecordUpdate(_, base, updates) => infer_record_update_type_with_hint(
820                            unifier,
821                            supply,
822                            &env_cur,
823                            adts,
824                            &known_cur,
825                            base.as_ref(),
826                            updates,
827                            &ann_ty,
828                        )?,
829                        _ => {
830                            let (p1, t1) =
831                                infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?;
832                            unifier.unify(&t1, &ann_ty)?;
833                            (p1, t1)
834                        }
835                    }
836                } else {
837                    infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?
838                };
839                let def_ty = unifier.apply_type(&t1);
840                let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
841                    monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
842                } else {
843                    let scheme = generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
844                    reject_ambiguous_scheme(&scheme)?;
845                    scheme
846                };
847                env_cur.extend(v.name.clone(), scheme);
848                if let Some(known_variant) =
849                    known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
850                {
851                    known_cur.insert(
852                        v.name.clone(),
853                        KnownVariant {
854                            adt: known_variant.adt,
855                            variant: known_variant.variant,
856                        },
857                    );
858                } else {
859                    known_cur.remove(&v.name);
860                }
861            }
862
863            let (p_body, t_body) =
864                infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, cur)?;
865            Ok((p_body, t_body))
866        }
867        Expr::LetRec(_, bindings, body) => {
868            let mut env_seed = env.clone();
869            let mut known_seed = known.clone();
870            let mut binding_tys = HashMap::new();
871            for (var, _ann, _def) in bindings {
872                let tv = Type::var(supply.fresh(Some(var.name.clone())));
873                env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
874                known_seed.remove(&var.name);
875                binding_tys.insert(var.name.clone(), tv);
876            }
877
878            let mut inferred = Vec::with_capacity(bindings.len());
879            for (var, ann, def) in bindings {
880                let (preds, def_ty) =
881                    infer_expr_type(unifier, supply, &env_seed, adts, &known_seed, def)?;
882                if let Some(ann) = ann {
883                    let mut ann_vars = HashMap::new();
884                    let ann_ty = type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
885                    unifier.unify(&def_ty, &ann_ty)?;
886                }
887                let binding_ty = binding_tys
888                    .get(&var.name)
889                    .cloned()
890                    .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
891                unifier.unify(&binding_ty, &def_ty)?;
892                let resolved_ty = unifier.apply_type(&binding_ty);
893
894                if let Some(known_variant) =
895                    known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
896                {
897                    known_seed.insert(
898                        var.name.clone(),
899                        KnownVariant {
900                            adt: known_variant.adt,
901                            variant: known_variant.variant,
902                        },
903                    );
904                } else {
905                    known_seed.remove(&var.name);
906                }
907                inferred.push((var.name.clone(), preds, resolved_ty));
908            }
909
910            let mut env_body = env.clone();
911            for (name, preds, def_ty) in inferred {
912                let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
913                reject_ambiguous_scheme(&scheme)?;
914                env_body.extend(name, scheme);
915            }
916
917            let (p_body, t_body) =
918                infer_expr_type(unifier, supply, &env_body, adts, &known_seed, body)?;
919            Ok((p_body, t_body))
920        }
921        Expr::Ite(_, cond, then_expr, else_expr) => {
922            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, cond)?;
923            unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
924            let (p2, t2) = infer_expr_type(unifier, supply, env, adts, known, then_expr)?;
925            let (p3, t3) = infer_expr_type(unifier, supply, env, adts, known, else_expr)?;
926            unifier.unify(&t2, &t3)?;
927            let out_ty = unifier.apply_type(&t2);
928            let mut preds = p1;
929            preds.extend(p2);
930            preds.extend(p3);
931            Ok((preds, out_ty))
932        }
933        Expr::Tuple(_, elems) => {
934            let mut preds = Vec::new();
935            let mut types = Vec::new();
936            for elem in elems {
937                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
938                preds.extend(p1);
939                types.push(unifier.apply_type(&t1));
940            }
941            let tuple_ty = Type::tuple(types);
942            Ok((preds, tuple_ty))
943        }
944        Expr::List(_, elems) => {
945            let elem_tv = Type::var(supply.fresh(Some("a".into())));
946            let mut preds = Vec::new();
947            for elem in elems {
948                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
949                unifier.unify(&t1, &elem_tv)?;
950                preds.extend(p1);
951            }
952            let list_ty = Type::app(
953                Type::builtin(BuiltinTypeId::List),
954                unifier.apply_type(&elem_tv),
955            );
956            Ok((preds, list_ty))
957        }
958        Expr::Dict(_, kvs) => {
959            let elem_tv = Type::var(supply.fresh(Some("v".into())));
960            let mut preds = Vec::new();
961            for v in kvs.values() {
962                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
963                unifier.unify(&t1, &elem_tv)?;
964                preds.extend(p1);
965            }
966            let dict_ty = Type::app(
967                Type::builtin(BuiltinTypeId::Dict),
968                unifier.apply_type(&elem_tv),
969            );
970            Ok((preds, dict_ty))
971        }
972        Expr::Match(_, scrutinee, arms) => {
973            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, scrutinee.as_ref())?;
974            let mut preds = p1;
975            let res_ty = Type::var(supply.fresh(Some("match".into())));
976            let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
977
978            for (pat, expr) in arms {
979                let scrutinee_ty = unifier.apply_type(&t1);
980                let (p_pat, binds) = infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
981                preds.extend(p_pat);
982
983                let mut env_arm = env.clone();
984                for (name, ty) in binds {
985                    env_arm.extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
986                }
987                let mut known_arm = known.clone();
988                if let Expr::Var(var) = scrutinee.as_ref() {
989                    match pat {
990                        Pattern::Named(_, name, _) => {
991                            let name_sym = name.to_dotted_symbol();
992                            if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
993                                known_arm.insert(
994                                    var.name.clone(),
995                                    KnownVariant {
996                                        adt: adt.name.clone(),
997                                        variant: name_sym,
998                                    },
999                                );
1000                            } else {
1001                                known_arm.remove(&var.name);
1002                            }
1003                        }
1004                        _ => {
1005                            known_arm.remove(&var.name);
1006                        }
1007                    }
1008                }
1009                let (p_expr, t_expr) =
1010                    infer_expr_type(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1011                unifier.unify(&res_ty, &t_expr)?;
1012                preds.extend(p_expr);
1013            }
1014
1015            let scrutinee_ty = unifier.apply_type(&t1);
1016            check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1017            let out_ty = unifier.apply_type(&res_ty);
1018            Ok((preds, out_ty))
1019        }
1020        Expr::Ann(_, expr, ann) => {
1021            let ann_ty = type_from_annotation_expr(adts, ann)?;
1022            match expr.as_ref() {
1023                Expr::RecordUpdate(_, base, updates) => {
1024                    let (preds, out_ty) = infer_record_update_type_with_hint(
1025                        unifier,
1026                        supply,
1027                        env,
1028                        adts,
1029                        known,
1030                        base.as_ref(),
1031                        updates,
1032                        &ann_ty,
1033                    )?;
1034                    Ok((preds, out_ty))
1035                }
1036                _ => {
1037                    let (preds, expr_ty) =
1038                        infer_expr_type(unifier, supply, env, adts, known, expr)?;
1039                    unifier.unify(&expr_ty, &ann_ty)?;
1040                    let out_ty = unifier.apply_type(&ann_ty);
1041                    Ok((preds, out_ty))
1042                }
1043            }
1044        }
1045    }
1046}
1047
1048fn infer_expr(
1049    unifier: &mut Unifier<'_>,
1050    supply: &mut TypeVarSupply,
1051    env: &TypeEnv,
1052    adts: &HashMap<Symbol, AdtDecl>,
1053    known: &KnownVariants,
1054    expr: &Expr,
1055) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
1056    let span = *expr.span();
1057    let res = unifier.with_infer_depth(span, |unifier| {
1058        (|| {
1059            unifier.charge_infer_node()?;
1060            match expr {
1061                Expr::Bool(_, v) => {
1062                    let t = Type::builtin(BuiltinTypeId::Bool);
1063                    Ok((
1064                        vec![],
1065                        t.clone(),
1066                        TypedExpr::new(t, TypedExprKind::Bool(*v)),
1067                    ))
1068                }
1069                Expr::Uint(_, v) => {
1070                    let t = Type::var(supply.fresh(Some(sym("n"))));
1071                    Ok((
1072                        vec![Predicate::new("Integral", t.clone())],
1073                        t.clone(),
1074                        TypedExpr::new(t, TypedExprKind::Uint(*v)),
1075                    ))
1076                }
1077                Expr::Int(_, v) => {
1078                    let t = Type::var(supply.fresh(Some(sym("n"))));
1079                    Ok((
1080                        vec![
1081                            Predicate::new("Integral", t.clone()),
1082                            Predicate::new("AdditiveGroup", t.clone()),
1083                        ],
1084                        t.clone(),
1085                        TypedExpr::new(t, TypedExprKind::Int(*v)),
1086                    ))
1087                }
1088                Expr::Float(_, v) => {
1089                    let t = Type::builtin(BuiltinTypeId::F32);
1090                    Ok((
1091                        vec![],
1092                        t.clone(),
1093                        TypedExpr::new(t, TypedExprKind::Float(*v)),
1094                    ))
1095                }
1096                Expr::String(_, v) => {
1097                    let t = Type::builtin(BuiltinTypeId::String);
1098                    Ok((
1099                        vec![],
1100                        t.clone(),
1101                        TypedExpr::new(t, TypedExprKind::String(v.clone())),
1102                    ))
1103                }
1104                Expr::Uuid(_, v) => {
1105                    let t = Type::builtin(BuiltinTypeId::Uuid);
1106                    Ok((
1107                        vec![],
1108                        t.clone(),
1109                        TypedExpr::new(t, TypedExprKind::Uuid(*v)),
1110                    ))
1111                }
1112                Expr::DateTime(_, v) => {
1113                    let t = Type::builtin(BuiltinTypeId::DateTime);
1114                    Ok((
1115                        vec![],
1116                        t.clone(),
1117                        TypedExpr::new(t, TypedExprKind::DateTime(*v)),
1118                    ))
1119                }
1120                Expr::Hole(_) => {
1121                    let t = Type::var(supply.fresh(Some(sym("hole"))));
1122                    Ok((vec![], t.clone(), TypedExpr::new(t, TypedExprKind::Hole)))
1123                }
1124                Expr::Var(var) => {
1125                    let schemes = env
1126                        .lookup(&var.name)
1127                        .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1128                    if schemes.len() == 1 {
1129                        let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
1130                        let (preds, t) = instantiate(&scheme, supply);
1131                        let typed = TypedExpr::new(
1132                            t.clone(),
1133                            TypedExprKind::Var {
1134                                name: var.name.clone(),
1135                                overloads: vec![],
1136                            },
1137                        );
1138                        Ok((preds, t, typed))
1139                    } else {
1140                        let mut overloads = Vec::new();
1141                        for scheme in schemes {
1142                            if !scheme.preds.is_empty() {
1143                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
1144                            }
1145
1146                            let scheme = apply_scheme_with_unifier(scheme, unifier);
1147                            let (preds, typ) = instantiate(&scheme, supply);
1148                            if !preds.is_empty() {
1149                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
1150                            }
1151                            overloads.push(typ);
1152                        }
1153                        let t = Type::var(supply.fresh(Some(var.name.clone())));
1154                        let typed = TypedExpr::new(
1155                            t.clone(),
1156                            TypedExprKind::Var {
1157                                name: var.name.clone(),
1158                                overloads,
1159                            },
1160                        );
1161                        Ok((vec![], t, typed))
1162                    }
1163                }
1164                Expr::Lam(..) => {
1165                    let (params, constraints, body) = collect_lambda_chain(expr);
1166                    let mut ann_vars = HashMap::new();
1167                    let mut param_tys = Vec::with_capacity(params.len());
1168                    for (name, ann) in &params {
1169                        let param_ty = match ann {
1170                            Some(ann) => {
1171                                type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?
1172                            }
1173                            None => Type::var(supply.fresh(Some(name.clone()))),
1174                        };
1175                        param_tys.push((name.clone(), param_ty));
1176                    }
1177
1178                    let mut env1 = env.clone();
1179                    let mut known_body = known.clone();
1180                    for (name, param_ty) in &param_tys {
1181                        env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
1182                        known_body.remove(name);
1183                    }
1184
1185                    let (mut preds, body_ty, typed_body) =
1186                        infer_expr(unifier, supply, &env1, adts, &known_body, body)?;
1187                    let constraint_preds =
1188                        predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
1189                    preds.extend(constraint_preds);
1190
1191                    let mut typed = typed_body;
1192                    let mut fun_ty = unifier.apply_type(&body_ty);
1193                    for (name, param_ty) in param_tys.iter().rev() {
1194                        fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
1195                        typed = TypedExpr::new(
1196                            fun_ty.clone(),
1197                            TypedExprKind::Lam {
1198                                param: name.clone(),
1199                                body: Box::new(typed),
1200                            },
1201                        );
1202                    }
1203
1204                    Ok((preds, fun_ty, typed))
1205                }
1206                Expr::App(..) => {
1207                    let (head, args) = collect_app_chain(expr);
1208                    let (mut preds, mut func_ty, mut typed) =
1209                        infer_expr(unifier, supply, env, adts, known, head)?;
1210                    let mut overload_name = None;
1211                    let mut overload_candidates = match &typed.kind {
1212                        TypedExprKind::Var { name, overloads } if !overloads.is_empty() => {
1213                            overload_name = Some(name.clone());
1214                            Some(overloads.clone())
1215                        }
1216                        _ => None,
1217                    };
1218                    for arg in args {
1219                        let expected_arg = match unifier.apply_type(&func_ty).as_ref() {
1220                            TypeKind::Fun(arg, _) => Some(arg.clone()),
1221                            _ => None,
1222                        };
1223                        let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
1224                            TypeKind::Fun(arg, _) => Some(arg.clone()),
1225                            _ => None,
1226                        };
1227                        let (p_arg, arg_ty, typed_arg) =
1228                            infer_app_arg_typed(unifier, supply, env, adts, known, arg_hint, arg)?;
1229                        let mut arg_ty = unifier.apply_type(&arg_ty);
1230                        let mut typed_arg = typed_arg;
1231
1232                        if let Some(expected_arg) = expected_arg {
1233                            let expected_arg = unifier.apply_type(&expected_arg);
1234                            if let (Some(expected_elem), Some(arg_elem)) = (
1235                                unary_app_arg(&expected_arg, "Array"),
1236                                unary_app_arg(&arg_ty, "List"),
1237                            ) {
1238                                unifier.unify(&expected_elem, &arg_elem)?;
1239                                let elem_ty = unifier.apply_type(&expected_elem);
1240                                let list_ty = Type::list(elem_ty.clone());
1241                                let array_ty = Type::array(elem_ty);
1242                                let coercion_ty = Type::fun(list_ty, array_ty.clone());
1243                                let coercion_fn = TypedExpr::new(
1244                                    coercion_ty,
1245                                    TypedExprKind::Var {
1246                                        name: sym("prim_array_from_list"),
1247                                        overloads: vec![],
1248                                    },
1249                                );
1250                                typed_arg = TypedExpr::new(
1251                                    array_ty.clone(),
1252                                    TypedExprKind::App(Box::new(coercion_fn), Box::new(typed_arg)),
1253                                );
1254                                arg_ty = array_ty;
1255                            }
1256                        }
1257                        if let Some(candidates) = overload_candidates.take() {
1258                            let candidates = candidates
1259                                .into_iter()
1260                                .map(|t| unifier.apply_type(&t))
1261                                .collect::<Vec<_>>();
1262                            let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
1263                            if narrowed.is_empty()
1264                                && let Some(name) = &overload_name
1265                            {
1266                                return Err(TypeError::AmbiguousOverload(name.clone()));
1267                            }
1268                            overload_candidates = Some(narrowed);
1269                        }
1270                        let res_ty = match overload_candidates.as_ref() {
1271                            Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
1272                            _ => Type::var(supply.fresh(Some("r".into()))),
1273                        };
1274                        unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
1275                        let result_ty = match overload_candidates.as_ref() {
1276                            Some(candidates) if candidates.len() == 1 => {
1277                                unifier.apply_type(&candidates[0])
1278                            }
1279                            _ => unifier.apply_type(&res_ty),
1280                        };
1281                        preds.extend(p_arg);
1282                        typed = TypedExpr::new(
1283                            result_ty.clone(),
1284                            TypedExprKind::App(Box::new(typed), Box::new(typed_arg)),
1285                        );
1286                        func_ty = result_ty;
1287                    }
1288                    Ok((preds, func_ty, typed))
1289                }
1290                Expr::Project(_, base, field) => {
1291                    let (p1, t1, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
1292                    let base_ty = unifier.apply_type(&t1);
1293                    let known_variant =
1294                        known_variant_from_expr_with_known(base, &base_ty, adts, known);
1295                    let field_ty =
1296                        resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
1297                    let typed = TypedExpr::new(
1298                        field_ty.clone(),
1299                        TypedExprKind::Project {
1300                            expr: Box::new(typed_base),
1301                            field: field.clone(),
1302                        },
1303                    );
1304                    Ok((p1, field_ty, typed))
1305                }
1306                Expr::RecordUpdate(_, base, updates) => {
1307                    let (p_base, t_base, typed_base) =
1308                        infer_expr(unifier, supply, env, adts, known, base)?;
1309                    let base_ty = unifier.apply_type(&t_base);
1310                    let known_variant =
1311                        known_variant_from_expr_with_known(base, &base_ty, adts, known);
1312                    let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
1313                    let (result_ty, fields) = resolve_record_update(
1314                        unifier,
1315                        supply,
1316                        adts,
1317                        &base_ty,
1318                        known_variant,
1319                        &update_fields,
1320                    )?;
1321                    let expected: HashMap<_, _> = fields.into_iter().collect();
1322
1323                    let mut preds = p_base;
1324                    let mut typed_updates = BTreeMap::new();
1325                    for (k, v) in updates {
1326                        let expected_ty =
1327                            expected.get(k).ok_or_else(|| TypeError::UnknownField {
1328                                field: k.clone(),
1329                                typ: result_ty.to_string(),
1330                            })?;
1331                        let (p1, t1, typed_v) =
1332                            infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
1333                        unifier.unify(&t1, expected_ty)?;
1334                        preds.extend(p1);
1335                        typed_updates.insert(k.clone(), typed_v);
1336                    }
1337                    let typed = TypedExpr::new(
1338                        result_ty.clone(),
1339                        TypedExprKind::RecordUpdate {
1340                            base: Box::new(typed_base),
1341                            updates: typed_updates,
1342                        },
1343                    );
1344                    Ok((preds, result_ty, typed))
1345                }
1346                Expr::Let(..) => {
1347                    let mut bindings = Vec::new();
1348                    let mut cur = expr;
1349                    while let Expr::Let(_, v, ann, d, b) = cur {
1350                        bindings.push((v.clone(), ann.clone(), d.clone()));
1351                        cur = b.as_ref();
1352                    }
1353
1354                    let mut env_cur = env.clone();
1355                    let mut known_cur = known.clone();
1356                    let mut typed_defs = Vec::new();
1357                    for (v, ann, d) in bindings {
1358                        let (p1, t1, typed_def) = if let Some(ref ann_expr) = ann {
1359                            let mut ann_vars = HashMap::new();
1360                            let ann_ty = type_from_annotation_expr_vars(
1361                                adts,
1362                                ann_expr,
1363                                &mut ann_vars,
1364                                supply,
1365                            )?;
1366                            match d.as_ref() {
1367                                Expr::RecordUpdate(_, base, updates) => {
1368                                    infer_record_update_typed_with_hint(
1369                                        unifier,
1370                                        supply,
1371                                        &env_cur,
1372                                        adts,
1373                                        &known_cur,
1374                                        base.as_ref(),
1375                                        updates,
1376                                        &ann_ty,
1377                                    )?
1378                                }
1379                                _ => {
1380                                    let (p1, t1, typed_def) = infer_expr(
1381                                        unifier, supply, &env_cur, adts, &known_cur, &d,
1382                                    )?;
1383                                    unifier.unify(&t1, &ann_ty)?;
1384                                    (p1, t1, typed_def)
1385                                }
1386                            }
1387                        } else {
1388                            infer_expr(unifier, supply, &env_cur, adts, &known_cur, &d)?
1389                        };
1390                        let def_ty = unifier.apply_type(&t1);
1391                        let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
1392                            monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
1393                        } else {
1394                            let scheme =
1395                                generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
1396                            reject_ambiguous_scheme(&scheme)?;
1397                            scheme
1398                        };
1399                        env_cur.extend(v.name.clone(), scheme);
1400                        if let Some(known_variant) =
1401                            known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
1402                        {
1403                            known_cur.insert(
1404                                v.name.clone(),
1405                                KnownVariant {
1406                                    adt: known_variant.adt,
1407                                    variant: known_variant.variant,
1408                                },
1409                            );
1410                        } else {
1411                            known_cur.remove(&v.name);
1412                        }
1413                        typed_defs.push((v.name.clone(), typed_def));
1414                    }
1415
1416                    let (p_body, t_body, typed_body) =
1417                        infer_expr(unifier, supply, &env_cur, adts, &known_cur, cur)?;
1418
1419                    let mut typed = typed_body;
1420                    for (name, def) in typed_defs.into_iter().rev() {
1421                        typed = TypedExpr::new(
1422                            t_body.clone(),
1423                            TypedExprKind::Let {
1424                                name,
1425                                def: Box::new(def),
1426                                body: Box::new(typed),
1427                            },
1428                        );
1429                    }
1430                    Ok((p_body, t_body, typed))
1431                }
1432                Expr::LetRec(_, bindings, body) => {
1433                    let mut env_seed = env.clone();
1434                    let mut known_seed = known.clone();
1435                    let mut binding_tys = HashMap::new();
1436                    for (var, _ann, _def) in bindings {
1437                        let tv = Type::var(supply.fresh(Some(var.name.clone())));
1438                        env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
1439                        known_seed.remove(&var.name);
1440                        binding_tys.insert(var.name.clone(), tv);
1441                    }
1442
1443                    let mut inferred_defs = Vec::with_capacity(bindings.len());
1444                    for (var, ann, def) in bindings {
1445                        let (preds, def_ty, typed_def) =
1446                            infer_expr(unifier, supply, &env_seed, adts, &known_seed, def)?;
1447                        if let Some(ann) = ann {
1448                            let mut ann_vars = HashMap::new();
1449                            let ann_ty =
1450                                type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
1451                            unifier.unify(&def_ty, &ann_ty)?;
1452                        }
1453                        let binding_ty = binding_tys
1454                            .get(&var.name)
1455                            .cloned()
1456                            .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1457                        unifier.unify(&binding_ty, &def_ty)?;
1458                        let resolved_ty = unifier.apply_type(&binding_ty);
1459
1460                        if let Some(known_variant) =
1461                            known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
1462                        {
1463                            known_seed.insert(
1464                                var.name.clone(),
1465                                KnownVariant {
1466                                    adt: known_variant.adt,
1467                                    variant: known_variant.variant,
1468                                },
1469                            );
1470                        } else {
1471                            known_seed.remove(&var.name);
1472                        }
1473                        inferred_defs.push((var.name.clone(), preds, resolved_ty, typed_def));
1474                    }
1475
1476                    let mut env_body = env.clone();
1477                    let mut typed_bindings = Vec::with_capacity(inferred_defs.len());
1478                    for (name, preds, def_ty, typed_def) in inferred_defs {
1479                        let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
1480                        reject_ambiguous_scheme(&scheme)?;
1481                        env_body.extend(name.clone(), scheme);
1482                        typed_bindings.push((name, typed_def));
1483                    }
1484
1485                    let (p_body, t_body, typed_body) =
1486                        infer_expr(unifier, supply, &env_body, adts, &known_seed, body)?;
1487                    let typed = TypedExpr::new(
1488                        t_body.clone(),
1489                        TypedExprKind::LetRec {
1490                            bindings: typed_bindings,
1491                            body: Box::new(typed_body),
1492                        },
1493                    );
1494                    Ok((p_body, t_body, typed))
1495                }
1496                Expr::Ite(_, cond, then_expr, else_expr) => {
1497                    let (p1, t1, typed_cond) = infer_expr(unifier, supply, env, adts, known, cond)?;
1498                    unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
1499                    let (p2, t2, typed_then) =
1500                        infer_expr(unifier, supply, env, adts, known, then_expr)?;
1501                    let (p3, t3, typed_else) =
1502                        infer_expr(unifier, supply, env, adts, known, else_expr)?;
1503                    unifier.unify(&t2, &t3)?;
1504                    let out_ty = unifier.apply_type(&t2);
1505                    let mut preds = p1;
1506                    preds.extend(p2);
1507                    preds.extend(p3);
1508                    let typed = TypedExpr::new(
1509                        out_ty.clone(),
1510                        TypedExprKind::Ite {
1511                            cond: Box::new(typed_cond),
1512                            then_expr: Box::new(typed_then),
1513                            else_expr: Box::new(typed_else),
1514                        },
1515                    );
1516                    Ok((preds, out_ty, typed))
1517                }
1518                Expr::Tuple(_, elems) => {
1519                    let mut preds = Vec::new();
1520                    let mut types = Vec::new();
1521                    let mut typed_elems = Vec::new();
1522                    for elem in elems {
1523                        let (p1, t1, typed_elem) =
1524                            infer_expr(unifier, supply, env, adts, known, elem)?;
1525                        preds.extend(p1);
1526                        types.push(unifier.apply_type(&t1));
1527                        typed_elems.push(typed_elem);
1528                    }
1529                    let tuple_ty = Type::tuple(types);
1530                    let typed = TypedExpr::new(tuple_ty.clone(), TypedExprKind::Tuple(typed_elems));
1531                    Ok((preds, tuple_ty, typed))
1532                }
1533                Expr::List(_, elems) => {
1534                    let elem_tv = Type::var(supply.fresh(Some("a".into())));
1535                    let mut preds = Vec::new();
1536                    let mut typed_elems = Vec::new();
1537                    for elem in elems {
1538                        let (p1, t1, typed_elem) =
1539                            infer_expr(unifier, supply, env, adts, known, elem)?;
1540                        unifier.unify(&t1, &elem_tv)?;
1541                        preds.extend(p1);
1542                        typed_elems.push(typed_elem);
1543                    }
1544                    let list_ty = Type::app(
1545                        Type::builtin(BuiltinTypeId::List),
1546                        unifier.apply_type(&elem_tv),
1547                    );
1548                    let typed = TypedExpr::new(list_ty.clone(), TypedExprKind::List(typed_elems));
1549                    Ok((preds, list_ty, typed))
1550                }
1551                Expr::Dict(_, kvs) => {
1552                    let elem_tv = Type::var(supply.fresh(Some("v".into())));
1553                    let mut preds = Vec::new();
1554                    let mut typed_kvs = BTreeMap::new();
1555                    for (k, v) in kvs {
1556                        let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v)?;
1557                        unifier.unify(&t1, &elem_tv)?;
1558                        preds.extend(p1);
1559                        typed_kvs.insert(k.clone(), typed_v);
1560                    }
1561                    let dict_ty = Type::app(
1562                        Type::builtin(BuiltinTypeId::Dict),
1563                        unifier.apply_type(&elem_tv),
1564                    );
1565                    let typed = TypedExpr::new(dict_ty.clone(), TypedExprKind::Dict(typed_kvs));
1566                    Ok((preds, dict_ty, typed))
1567                }
1568                Expr::Match(_, scrutinee, arms) => {
1569                    let (p1, t1, typed_scrutinee) =
1570                        infer_expr(unifier, supply, env, adts, known, scrutinee)?;
1571                    let mut preds = p1;
1572                    let mut typed_arms = Vec::new();
1573                    let res_ty = Type::var(supply.fresh(Some("match".into())));
1574                    let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
1575
1576                    for (pat, expr) in arms {
1577                        let scrutinee_ty = unifier.apply_type(&t1);
1578                        let (p_pat, binds) =
1579                            infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
1580                        preds.extend(p_pat);
1581
1582                        let mut env_arm = env.clone();
1583                        for (name, ty) in binds {
1584                            env_arm
1585                                .extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
1586                        }
1587                        let mut known_arm = known.clone();
1588                        if let Expr::Var(var) = scrutinee.as_ref() {
1589                            match pat {
1590                                Pattern::Named(_, name, _) => {
1591                                    let name_sym = name.to_dotted_symbol();
1592                                    if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
1593                                        known_arm.insert(
1594                                            var.name.clone(),
1595                                            KnownVariant {
1596                                                adt: adt.name.clone(),
1597                                                variant: name_sym,
1598                                            },
1599                                        );
1600                                    } else {
1601                                        known_arm.remove(&var.name);
1602                                    }
1603                                }
1604                                _ => {
1605                                    known_arm.remove(&var.name);
1606                                }
1607                            }
1608                        }
1609                        let (p_expr, t_expr, typed_expr) =
1610                            infer_expr(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1611                        unifier.unify(&res_ty, &t_expr)?;
1612                        preds.extend(p_expr);
1613                        typed_arms.push((pat.clone(), typed_expr));
1614                    }
1615
1616                    let scrutinee_ty = unifier.apply_type(&t1);
1617                    check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1618                    let out_ty = unifier.apply_type(&res_ty);
1619                    let typed = TypedExpr::new(
1620                        out_ty.clone(),
1621                        TypedExprKind::Match {
1622                            scrutinee: Box::new(typed_scrutinee),
1623                            arms: typed_arms,
1624                        },
1625                    );
1626                    Ok((preds, out_ty, typed))
1627                }
1628                Expr::Ann(_, expr, ann) => {
1629                    let ann_ty = type_from_annotation_expr(adts, ann)?;
1630                    match expr.as_ref() {
1631                        Expr::RecordUpdate(_, base, updates) => {
1632                            infer_record_update_typed_with_hint(
1633                                unifier,
1634                                supply,
1635                                env,
1636                                adts,
1637                                known,
1638                                base.as_ref(),
1639                                updates,
1640                                &ann_ty,
1641                            )
1642                        }
1643                        _ => {
1644                            let (preds, expr_ty, typed_expr) =
1645                                infer_expr(unifier, supply, env, adts, known, expr)?;
1646                            unifier.unify(&expr_ty, &ann_ty)?;
1647                            let out_ty = unifier.apply_type(&ann_ty);
1648                            Ok((preds, out_ty, typed_expr))
1649                        }
1650                    }
1651                }
1652            }
1653        })()
1654    });
1655    res.map_err(|err| with_span(&span, err))
1656}
1657
1658fn ctor_lookup<'a>(
1659    adts: &'a HashMap<Symbol, AdtDecl>,
1660    name: &Symbol,
1661) -> Option<(&'a AdtDecl, &'a AdtVariant)> {
1662    let mut found = None;
1663    for adt in adts.values() {
1664        if let Some(variant) = adt.variants.iter().find(|v| &v.name == name) {
1665            if found.is_some() {
1666                return None;
1667            }
1668            found = Some((adt, variant));
1669        }
1670    }
1671    found
1672}
1673
1674fn record_fields(variant: &AdtVariant) -> Option<&[(Symbol, Type)]> {
1675    if variant.args.len() != 1 {
1676        return None;
1677    }
1678    match variant.args[0].as_ref() {
1679        TypeKind::Record(fields) => Some(fields),
1680        _ => None,
1681    }
1682}
1683
1684fn instantiate_variant_fields(
1685    adt: &AdtDecl,
1686    variant: &AdtVariant,
1687    supply: &mut TypeVarSupply,
1688) -> Option<(Type, Vec<(Symbol, Type)>)> {
1689    let fields = record_fields(variant)?;
1690    let mut subst = Subst::new_sync();
1691    for param in &adt.params {
1692        let fresh = Type::var(supply.fresh(param.var.name.clone()));
1693        subst = subst.insert(param.var.id, fresh);
1694    }
1695    let result_ty = adt.result_type().apply(&subst);
1696    let fields = fields
1697        .iter()
1698        .map(|(name, ty)| (name.clone(), ty.apply(&subst)))
1699        .collect();
1700    Some((result_ty, fields))
1701}
1702
1703fn known_variant_from_expr(
1704    expr: &Expr,
1705    expr_ty: &Type,
1706    adts: &HashMap<Symbol, AdtDecl>,
1707) -> Option<KnownVariant> {
1708    let mut expr = expr;
1709    while let Expr::Ann(_, inner, _) = expr {
1710        expr = inner.as_ref();
1711    }
1712    if matches!(expr_ty.as_ref(), TypeKind::Fun(..)) {
1713        return None;
1714    }
1715    let ctor = match expr {
1716        Expr::App(_, f, _) => match f.as_ref() {
1717            Expr::Var(var) => var.name.clone(),
1718            _ => return None,
1719        },
1720        _ => return None,
1721    };
1722    let (adt, variant) = ctor_lookup(adts, &ctor)?;
1723    record_fields(variant)?;
1724    Some(KnownVariant {
1725        adt: adt.name.clone(),
1726        variant: variant.name.clone(),
1727    })
1728}
1729
1730fn known_variant_from_expr_with_known(
1731    expr: &Expr,
1732    expr_ty: &Type,
1733    adts: &HashMap<Symbol, AdtDecl>,
1734    known: &KnownVariants,
1735) -> Option<KnownVariant> {
1736    let mut expr = expr;
1737    while let Expr::Ann(_, inner, _) = expr {
1738        expr = inner.as_ref();
1739    }
1740    match expr {
1741        Expr::Var(var) => known.get(&var.name).cloned(),
1742        Expr::RecordUpdate(_, base, _) => {
1743            known_variant_from_expr_with_known(base.as_ref(), expr_ty, adts, known)
1744        }
1745        _ => known_variant_from_expr(expr, expr_ty, adts),
1746    }
1747}
1748
1749fn select_record_variant<'a, F>(
1750    adts: &'a HashMap<Symbol, AdtDecl>,
1751    base_ty: &Type,
1752    known_variant: Option<KnownVariant>,
1753    field_for_errors: &Symbol,
1754    matches_fields: F,
1755) -> Result<(&'a AdtDecl, &'a AdtVariant), TypeError>
1756where
1757    F: Fn(&[(Symbol, Type)]) -> bool,
1758{
1759    if let Some(info) = known_variant {
1760        let adt = adts
1761            .get(&info.adt)
1762            .ok_or_else(|| TypeError::UnknownTypeName(info.adt.clone()))?;
1763        let variant = adt
1764            .variants
1765            .iter()
1766            .find(|v| v.name == info.variant)
1767            .ok_or_else(|| TypeError::UnknownField {
1768                field: field_for_errors.clone(),
1769                typ: base_ty.to_string(),
1770            })?;
1771        return Ok((adt, variant));
1772    }
1773
1774    if let Some(adt_name) = type_head_name(base_ty) {
1775        let adt = adts.get(adt_name).ok_or_else(|| TypeError::UnknownField {
1776            field: field_for_errors.clone(),
1777            typ: base_ty.to_string(),
1778        })?;
1779        if adt.variants.len() == 1 {
1780            return Ok((adt, &adt.variants[0]));
1781        }
1782        return Err(TypeError::FieldNotKnown {
1783            field: field_for_errors.clone(),
1784            typ: base_ty.to_string(),
1785        });
1786    }
1787
1788    if matches!(base_ty.as_ref(), TypeKind::Var(_)) {
1789        let mut candidates = Vec::new();
1790        for adt in adts.values() {
1791            if adt.variants.len() != 1 {
1792                continue;
1793            }
1794            let variant = &adt.variants[0];
1795            let Some(fields) = record_fields(variant) else {
1796                continue;
1797            };
1798            if matches_fields(fields) {
1799                candidates.push((adt, variant));
1800            }
1801        }
1802        if candidates.len() == 1 {
1803            return Ok(candidates.remove(0));
1804        }
1805        if candidates.is_empty() {
1806            return Err(TypeError::UnknownField {
1807                field: field_for_errors.clone(),
1808                typ: base_ty.to_string(),
1809            });
1810        }
1811        return Err(TypeError::FieldNotKnown {
1812            field: field_for_errors.clone(),
1813            typ: base_ty.to_string(),
1814        });
1815    }
1816
1817    Err(TypeError::UnknownField {
1818        field: field_for_errors.clone(),
1819        typ: base_ty.to_string(),
1820    })
1821}
1822
1823fn resolve_record_update(
1824    unifier: &mut Unifier<'_>,
1825    supply: &mut TypeVarSupply,
1826    adts: &HashMap<Symbol, AdtDecl>,
1827    base_ty: &Type,
1828    known_variant: Option<KnownVariant>,
1829    update_fields: &[Symbol],
1830) -> Result<(Type, Vec<(Symbol, Type)>), TypeError> {
1831    if let TypeKind::Record(fields) = base_ty.as_ref() {
1832        return Ok((base_ty.clone(), fields.clone()));
1833    }
1834
1835    let field_for_errors = update_fields.first().cloned().unwrap_or_else(|| sym("_"));
1836
1837    let (adt, variant) =
1838        select_record_variant(adts, base_ty, known_variant, &field_for_errors, |fields| {
1839            update_fields
1840                .iter()
1841                .all(|field| fields.iter().any(|(name, _)| name == field))
1842        })?;
1843
1844    let (result_ty, fields) =
1845        instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1846            TypeError::UnknownField {
1847                field: field_for_errors.clone(),
1848                typ: base_ty.to_string(),
1849            }
1850        })?;
1851
1852    for field in update_fields {
1853        if fields.iter().all(|(name, _)| name != field) {
1854            return Err(TypeError::UnknownField {
1855                field: field.clone(),
1856                typ: base_ty.to_string(),
1857            });
1858        }
1859    }
1860
1861    unifier.unify(base_ty, &result_ty)?;
1862    let result_ty = unifier.apply_type(&result_ty);
1863    let fields = fields
1864        .into_iter()
1865        .map(|(name, ty)| (name, unifier.apply_type(&ty)))
1866        .collect();
1867    Ok((result_ty, fields))
1868}
1869
1870fn resolve_projection(
1871    unifier: &mut Unifier<'_>,
1872    supply: &mut TypeVarSupply,
1873    adts: &HashMap<Symbol, AdtDecl>,
1874    base_ty: &Type,
1875    known_variant: Option<KnownVariant>,
1876    field: &Symbol,
1877) -> Result<Type, TypeError> {
1878    if let Ok(index) = field.as_ref().parse::<usize>() {
1879        let elem_ty = match base_ty.as_ref() {
1880            TypeKind::Tuple(elems) => {
1881                elems
1882                    .get(index)
1883                    .cloned()
1884                    .ok_or_else(|| TypeError::UnknownField {
1885                        field: field.clone(),
1886                        typ: base_ty.to_string(),
1887                    })?
1888            }
1889            TypeKind::Var(_) => {
1890                let mut elems = Vec::with_capacity(index + 1);
1891                for _ in 0..=index {
1892                    elems.push(Type::var(supply.fresh(Some(sym("t")))));
1893                }
1894                let tuple_ty = Type::tuple(elems.clone());
1895                unifier.unify(base_ty, &tuple_ty)?;
1896                elems[index].clone()
1897            }
1898            _ => {
1899                return Err(TypeError::UnknownField {
1900                    field: field.clone(),
1901                    typ: base_ty.to_string(),
1902                });
1903            }
1904        };
1905        return Ok(unifier.apply_type(&elem_ty));
1906    }
1907
1908    let (adt, variant) = select_record_variant(adts, base_ty, known_variant, field, |fields| {
1909        fields.iter().any(|(name, _)| name == field)
1910    })?;
1911
1912    let (result_ty, fields) =
1913        instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1914            TypeError::UnknownField {
1915                field: field.clone(),
1916                typ: base_ty.to_string(),
1917            }
1918        })?;
1919    let field_ty = fields
1920        .iter()
1921        .find(|(name, _)| name == field)
1922        .map(|(_, ty)| ty.clone())
1923        .ok_or_else(|| TypeError::UnknownField {
1924            field: field.clone(),
1925            typ: base_ty.to_string(),
1926        })?;
1927    unifier.unify(base_ty, &result_ty)?;
1928    Ok(unifier.apply_type(&field_ty))
1929}
1930
1931fn decompose_fun(typ: &Type, arity: usize) -> Option<(Vec<Type>, Type)> {
1932    let mut args = Vec::with_capacity(arity);
1933    let mut cur = typ.clone();
1934    for _ in 0..arity {
1935        match cur.as_ref() {
1936            TypeKind::Fun(a, b) => {
1937                args.push(a.clone());
1938                cur = b.clone();
1939            }
1940            _ => return None,
1941        }
1942    }
1943    Some((args, cur))
1944}
1945
1946type InferPatternResult = (Vec<Predicate>, Vec<(Symbol, Type)>);
1947
1948fn infer_pattern(
1949    unifier: &mut Unifier<'_>,
1950    supply: &mut TypeVarSupply,
1951    env: &TypeEnv,
1952    pat: &Pattern,
1953    scrutinee_ty: &Type,
1954) -> Result<InferPatternResult, TypeError> {
1955    let span = *pat.span();
1956    let res = (|| {
1957        unifier.charge_infer_node()?;
1958        match pat {
1959            Pattern::Wildcard(..) => Ok((vec![], vec![])),
1960            Pattern::Var(var) => Ok((
1961                vec![],
1962                vec![(var.name.clone(), unifier.apply_type(scrutinee_ty))],
1963            )),
1964            Pattern::Named(_, name, ps) => {
1965                let ctor_name = name.to_dotted_symbol();
1966                let schemes = env
1967                    .lookup(&ctor_name)
1968                    .ok_or_else(|| TypeError::UnknownVar(ctor_name.clone()))?;
1969                if schemes.len() != 1 {
1970                    return Err(TypeError::AmbiguousOverload(ctor_name));
1971                }
1972                let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
1973                let (preds, ctor_ty) = instantiate(&scheme, supply);
1974                let (arg_tys, res_ty) = decompose_fun(&ctor_ty, ps.len())
1975                    .ok_or(TypeError::UnsupportedExpr("pattern constructor"))?;
1976                unifier.unify(&res_ty, scrutinee_ty)?;
1977                let mut all_preds = preds;
1978                let mut bindings = Vec::new();
1979                for (p, arg_ty) in ps.iter().zip(arg_tys.iter()) {
1980                    let arg_ty = unifier.apply_type(arg_ty);
1981                    let (p1, binds1) = infer_pattern(unifier, supply, env, p, &arg_ty)?;
1982                    all_preds.extend(p1);
1983                    bindings.extend(binds1);
1984                }
1985                let bindings = bindings
1986                    .into_iter()
1987                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
1988                    .collect();
1989                Ok((all_preds, bindings))
1990            }
1991            Pattern::List(_, ps) => {
1992                let elem_tv = Type::var(supply.fresh(Some("a".into())));
1993                let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
1994                unifier.unify(scrutinee_ty, &list_ty)?;
1995                let mut preds = Vec::new();
1996                let mut bindings = Vec::new();
1997                for p in ps {
1998                    let elem_ty = unifier.apply_type(&elem_tv);
1999                    let (p1, binds1) = infer_pattern(unifier, supply, env, p, &elem_ty)?;
2000                    preds.extend(p1);
2001                    bindings.extend(binds1);
2002                }
2003                let bindings = bindings
2004                    .into_iter()
2005                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2006                    .collect();
2007                Ok((preds, bindings))
2008            }
2009            Pattern::Cons(_, head, tail) => {
2010                let elem_tv = Type::var(supply.fresh(Some("a".into())));
2011                let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
2012                unifier.unify(scrutinee_ty, &list_ty)?;
2013                let mut preds = Vec::new();
2014                let mut bindings = Vec::new();
2015
2016                let head_ty = unifier.apply_type(&elem_tv);
2017                let (p1, binds1) = infer_pattern(unifier, supply, env, head, &head_ty)?;
2018                preds.extend(p1);
2019                bindings.extend(binds1);
2020
2021                let tail_ty = Type::app(
2022                    Type::builtin(BuiltinTypeId::List),
2023                    unifier.apply_type(&elem_tv),
2024                );
2025                let (p2, binds2) = infer_pattern(unifier, supply, env, tail, &tail_ty)?;
2026                preds.extend(p2);
2027                bindings.extend(binds2);
2028
2029                let bindings = bindings
2030                    .into_iter()
2031                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2032                    .collect();
2033                Ok((preds, bindings))
2034            }
2035            Pattern::Tuple(_, elems) => {
2036                let mut elem_tys: Vec<Type> = (0..elems.len())
2037                    .map(|i| Type::var(supply.fresh(Some(format!("t{i}").into()))))
2038                    .collect();
2039                let expected = Type::tuple(elem_tys.clone());
2040                unifier.unify(scrutinee_ty, &expected)?;
2041                elem_tys = elem_tys
2042                    .into_iter()
2043                    .map(|t| unifier.apply_type(&t))
2044                    .collect();
2045
2046                let mut preds = Vec::new();
2047                let mut bindings = Vec::new();
2048                for (p, ty) in elems.iter().zip(elem_tys.iter()) {
2049                    let (p_preds, p_binds) = infer_pattern(unifier, supply, env, p, ty)?;
2050                    preds.extend(p_preds);
2051                    bindings.extend(p_binds);
2052                }
2053                let bindings = bindings
2054                    .into_iter()
2055                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2056                    .collect();
2057                Ok((preds, bindings))
2058            }
2059            Pattern::Dict(_, fields) => {
2060                if let TypeKind::Record(ty_fields) = scrutinee_ty.as_ref() {
2061                    let mut preds = Vec::new();
2062                    let mut bindings = Vec::new();
2063                    for (key, pat) in fields {
2064                        let ty = ty_fields
2065                            .iter()
2066                            .find(|(name, _)| name == key)
2067                            .map(|(_, ty)| unifier.apply_type(ty))
2068                            .ok_or_else(|| TypeError::UnknownField {
2069                                field: key.clone(),
2070                                typ: scrutinee_ty.to_string(),
2071                            })?;
2072                        let (p_preds, p_binds) = infer_pattern(unifier, supply, env, pat, &ty)?;
2073                        preds.extend(p_preds);
2074                        bindings.extend(p_binds);
2075                    }
2076                    let bindings = bindings
2077                        .into_iter()
2078                        .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2079                        .collect();
2080                    Ok((preds, bindings))
2081                } else {
2082                    let elem_tv = Type::var(supply.fresh(Some("v".into())));
2083                    let dict_ty = Type::app(Type::builtin(BuiltinTypeId::Dict), elem_tv.clone());
2084                    unifier.unify(scrutinee_ty, &dict_ty)?;
2085                    let elem_ty = unifier.apply_type(&elem_tv);
2086
2087                    let mut preds = Vec::new();
2088                    let mut bindings = Vec::new();
2089                    for (_key, pat) in fields {
2090                        let (p_preds, p_binds) =
2091                            infer_pattern(unifier, supply, env, pat, &elem_ty)?;
2092                        preds.extend(p_preds);
2093                        bindings.extend(p_binds);
2094                    }
2095                    let bindings = bindings
2096                        .into_iter()
2097                        .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2098                        .collect();
2099                    Ok((preds, bindings))
2100                }
2101            }
2102        }
2103    })();
2104    res.map_err(|err| with_span(&span, err))
2105}
2106
2107fn type_head_name(typ: &Type) -> Option<&Symbol> {
2108    let mut cur = typ;
2109    while let TypeKind::App(head, _) = cur.as_ref() {
2110        cur = head;
2111    }
2112    match cur.as_ref() {
2113        TypeKind::Con(tc) => Some(&tc.name),
2114        _ => None,
2115    }
2116}
2117
2118fn adt_name_from_patterns(adts: &HashMap<Symbol, AdtDecl>, patterns: &[Pattern]) -> Option<Symbol> {
2119    let mut candidate: Option<Symbol> = None;
2120    for pat in patterns {
2121        let next = match pat {
2122            Pattern::Named(_, name, _) => {
2123                let name_sym = name.to_dotted_symbol();
2124                ctor_lookup(adts, &name_sym).map(|(adt, _)| adt.name.clone())
2125            }
2126            Pattern::List(..) | Pattern::Cons(..) => Some(sym("List")),
2127            _ => None,
2128        };
2129        if let Some(next) = next {
2130            match &candidate {
2131                None => candidate = Some(next),
2132                Some(prev) if *prev == next => {}
2133                Some(_) => return None,
2134            }
2135        }
2136    }
2137    candidate
2138}
2139
2140fn check_match_exhaustive(
2141    adts: &HashMap<Symbol, AdtDecl>,
2142    scrutinee_ty: &Type,
2143    patterns: &[Pattern],
2144) -> Result<(), TypeError> {
2145    if patterns
2146        .iter()
2147        .any(|p| matches!(p, Pattern::Wildcard(..) | Pattern::Var(_)))
2148    {
2149        return Ok(());
2150    }
2151    let adt_name = match type_head_name(scrutinee_ty).cloned() {
2152        Some(name) => name,
2153        None => match adt_name_from_patterns(adts, patterns) {
2154            Some(name) => name,
2155            None => return Ok(()),
2156        },
2157    };
2158    let adt = match adts.get(&adt_name) {
2159        Some(adt) => adt,
2160        None => return Ok(()),
2161    };
2162    let ctor_names: HashSet<Symbol> = adt.variants.iter().map(|v| v.name.clone()).collect();
2163    if ctor_names.is_empty() {
2164        return Ok(());
2165    }
2166    let mut covered = HashSet::new();
2167    for pat in patterns {
2168        match pat {
2169            Pattern::Named(_, name, _) => {
2170                let name_sym = name.to_dotted_symbol();
2171                if ctor_names.contains(&name_sym) {
2172                    covered.insert(name_sym);
2173                }
2174            }
2175            Pattern::List(_, elems) if adt_name.as_ref() == "List" && elems.is_empty() => {
2176                covered.insert(sym("Empty"));
2177            }
2178            Pattern::Cons(..) if adt_name.as_ref() == "List" => {
2179                covered.insert(sym("Cons"));
2180            }
2181            _ => {}
2182        }
2183    }
2184    let mut missing: Vec<Symbol> = ctor_names.difference(&covered).cloned().collect();
2185    if missing.is_empty() {
2186        return Ok(());
2187    }
2188    missing.sort();
2189    Err(TypeError::NonExhaustiveMatch {
2190        typ: scrutinee_ty.to_string(),
2191        missing,
2192    })
2193}
2194
2195#[cfg(test)]
2196mod tests {
2197    use super::*;
2198    use rexlang_lexer::Token;
2199    use rexlang_parser::Parser;
2200    use rexlang_util::{GasCosts, GasMeter};
2201
2202    fn tvar(id: TypeVarId, name: &str) -> Type {
2203        Type::var(TypeVar::new(id, Some(sym(name))))
2204    }
2205
2206    fn dict_of(elem: Type) -> Type {
2207        Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
2208    }
2209
2210    #[test]
2211    fn unify_simple() {
2212        let t1 = Type::fun(tvar(0, "a"), Type::builtin(BuiltinTypeId::U32));
2213        let t2 = Type::fun(Type::builtin(BuiltinTypeId::U16), tvar(1, "b"));
2214        let subst = unify(&t1, &t2).unwrap();
2215        assert_eq!(subst.get(&0), Some(&Type::builtin(BuiltinTypeId::U16)));
2216        assert_eq!(subst.get(&1), Some(&Type::builtin(BuiltinTypeId::U32)));
2217    }
2218
2219    #[test]
2220    fn occurs_check_blocks_infinite_type() {
2221        let tv = TypeVar::new(0, Some(sym("a")));
2222        let t = Type::fun(Type::var(tv.clone()), Type::builtin(BuiltinTypeId::U8));
2223        let err = bind(&tv, &t).unwrap_err();
2224        assert!(matches!(err, TypeError::Occurs(_, _)));
2225    }
2226
2227    #[test]
2228    fn instantiate_and_generalize_round_trip() {
2229        let mut supply = TypeVarSupply::new();
2230        let a = Type::var(supply.fresh(Some(sym("a"))));
2231        let scheme = generalize(&TypeEnv::new(), vec![], Type::fun(a.clone(), a.clone()));
2232        let (preds, inst) = instantiate(&scheme, &mut supply);
2233        assert!(preds.is_empty());
2234        if let TypeKind::Fun(l, r) = inst.as_ref() {
2235            match (l.as_ref(), r.as_ref()) {
2236                (TypeKind::Var(_), TypeKind::Var(_)) => {}
2237                _ => panic!("expected polymorphic identity"),
2238            }
2239        } else {
2240            panic!("expected function type");
2241        }
2242    }
2243
2244    #[test]
2245    fn entail_superclasses() {
2246        let ts = TypeSystem::new_with_prelude().unwrap();
2247        let pred = Predicate::new("Semiring", Type::builtin(BuiltinTypeId::I32));
2248        let given = [Predicate::new(
2249            "AdditiveGroup",
2250            Type::builtin(BuiltinTypeId::I32),
2251        )];
2252        assert!(entails(&ts.classes, &given, &pred).unwrap());
2253    }
2254
2255    #[test]
2256    fn entail_instances() {
2257        let ts = TypeSystem::new_with_prelude().unwrap();
2258        let pred = Predicate::new("Field", Type::builtin(BuiltinTypeId::F32));
2259        assert!(entails(&ts.classes, &[], &pred).unwrap());
2260
2261        let pred_fail = Predicate::new("Field", Type::builtin(BuiltinTypeId::U32));
2262        assert!(!entails(&ts.classes, &[], &pred_fail).unwrap());
2263    }
2264
2265    #[test]
2266    fn prelude_injects_functions() {
2267        let ts = TypeSystem::new_with_prelude().unwrap();
2268        let minus = ts.env.lookup(&sym("-")).expect("minus in env");
2269        let div = ts.env.lookup(&sym("/")).expect("div in env");
2270        assert_eq!(minus.len(), 1);
2271        assert_eq!(div.len(), 1);
2272        let minus = &minus[0];
2273        let div = &div[0];
2274        assert_eq!(minus.preds.len(), 1);
2275        assert_eq!(minus.vars.len(), 1);
2276        assert_eq!(div.preds.len(), 1);
2277        assert_eq!(div.vars.len(), 1);
2278    }
2279
2280    #[test]
2281    fn adt_constructors_are_present() {
2282        let ts = TypeSystem::new_with_prelude().unwrap();
2283        assert!(ts.env.lookup(&sym("Empty")).is_some());
2284        assert!(ts.env.lookup(&sym("Cons")).is_some());
2285        assert!(ts.env.lookup(&sym("Ok")).is_some());
2286        assert!(ts.env.lookup(&sym("Err")).is_some());
2287        assert!(ts.env.lookup(&sym("Some")).is_some());
2288        assert!(ts.env.lookup(&sym("None")).is_some());
2289    }
2290
2291    fn parse_expr(code: &str) -> std::sync::Arc<rexlang_ast::expr::Expr> {
2292        let mut parser = Parser::new(Token::tokenize(code).unwrap());
2293        parser.parse_program(&mut GasMeter::default()).unwrap().expr
2294    }
2295
2296    fn parse_program(code: &str) -> rexlang_ast::expr::Program {
2297        let mut parser = Parser::new(Token::tokenize(code).unwrap());
2298        parser.parse_program(&mut GasMeter::default()).unwrap()
2299    }
2300
2301    #[test]
2302    fn infer_deep_list_does_not_overflow() {
2303        const N: usize = 40;
2304        let mut code = String::new();
2305        code.push_str("let xs = ");
2306        for _ in 0..N {
2307            code.push_str("Cons 0 (");
2308        }
2309        code.push_str("Empty");
2310        for _ in 0..N {
2311            code.push(')');
2312        }
2313        code.push_str(" in xs");
2314
2315        let parse_handle = std::thread::Builder::new()
2316            .name("infer_deep_list_parse".into())
2317            .stack_size(128 * 1024 * 1024)
2318            .spawn(move || {
2319                let tokens = Token::tokenize(&code).unwrap();
2320                let mut parser = Parser::new(tokens);
2321                parser.parse_program(&mut GasMeter::default())
2322            })
2323            .unwrap();
2324        let program = parse_handle.join().unwrap().unwrap();
2325        let expr = program.expr;
2326        let mut ts = TypeSystem::new_with_prelude().unwrap();
2327        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2328        assert_eq!(
2329            ty,
2330            Type::app(
2331                Type::builtin(BuiltinTypeId::List),
2332                Type::builtin(BuiltinTypeId::I32)
2333            )
2334        );
2335    }
2336
2337    #[test]
2338    fn collect_adts_in_types_finds_nested_unique_adts() {
2339        let foo = Type::user_con("Foo", 1);
2340        let bar = Type::user_con("Bar", 0);
2341        let ty = Type::fun(
2342            Type::app(
2343                Type::builtin(BuiltinTypeId::List),
2344                Type::app(foo.clone(), tvar(0, "a")),
2345            ),
2346            Type::tuple(vec![
2347                Type::app(foo.clone(), Type::builtin(BuiltinTypeId::I32)),
2348                bar.clone(),
2349            ]),
2350        );
2351
2352        let adts = collect_adts_in_types(vec![ty]).unwrap();
2353        assert_eq!(adts, vec![foo, bar]);
2354    }
2355
2356    #[test]
2357    fn collect_adts_in_types_rejects_conflicting_names() {
2358        let arity1 = Type::user_con("Thing", 1);
2359        let arity2 = Type::user_con("Thing", 2);
2360
2361        let err = collect_adts_in_types(vec![arity1.clone(), arity2.clone()]).unwrap_err();
2362        assert_eq!(err.conflicts.len(), 1);
2363        let conflict = &err.conflicts[0];
2364        assert_eq!(conflict.name, sym("Thing"));
2365        assert_eq!(conflict.definitions, vec![arity1, arity2]);
2366    }
2367
2368    #[test]
2369    fn infer_depth_limit_is_enforced() {
2370        const N: usize = 40;
2371        let mut code = String::new();
2372        code.push_str("let xs = ");
2373        for _ in 0..N {
2374            code.push_str("Cons 0 (");
2375        }
2376        code.push_str("Empty");
2377        for _ in 0..N {
2378            code.push(')');
2379        }
2380        code.push_str(" in xs");
2381
2382        let program = parse_program(&code);
2383        let mut ts = TypeSystem::new_with_prelude().unwrap();
2384        ts.set_limits(TypeSystemLimits {
2385            max_infer_depth: Some(8),
2386        });
2387
2388        let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
2389        assert!(
2390            err.to_string().contains("maximum inference depth exceeded"),
2391            "expected a max-depth inference error, got: {err:?}"
2392        );
2393    }
2394
2395    #[test]
2396    fn declare_fn_injects_scheme_for_use_sites() {
2397        let program = parse_program(
2398            r#"
2399            declare fn id x: a -> a
2400            id 1
2401            "#,
2402        );
2403        let mut ts = TypeSystem::new_with_prelude().unwrap();
2404        ts.register_decls(&program.decls).unwrap();
2405        let (preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2406        assert!(
2407            preds.is_empty()
2408                || preds.iter().all(|p| p.class.as_ref() == "Integral"
2409                    && p.typ == Type::builtin(BuiltinTypeId::I32))
2410        );
2411        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2412    }
2413
2414    #[test]
2415    fn declare_fn_is_noop_when_matching_existing_scheme() {
2416        let mut ts = TypeSystem::new_with_prelude().unwrap();
2417        ts.add_value(
2418            "foo",
2419            Scheme::new(
2420                vec![],
2421                vec![],
2422                Type::fun(
2423                    Type::builtin(BuiltinTypeId::I32),
2424                    Type::builtin(BuiltinTypeId::I32),
2425                ),
2426            ),
2427        );
2428
2429        let program = parse_program(
2430            r#"
2431            declare fn foo x: i32 -> i32
2432            0
2433            "#,
2434        );
2435        let rexlang_ast::expr::Decl::DeclareFn(fd) = &program.decls[0] else {
2436            panic!("expected declare fn decl");
2437        };
2438        ts.inject_declare_fn_decl(fd).unwrap();
2439    }
2440
2441    #[test]
2442    fn unit_type_parses_and_infers() {
2443        let program = parse_program(
2444            r#"
2445            fn unit_id x: () -> () = x
2446            unit_id ()
2447            "#,
2448        );
2449        let mut ts = TypeSystem::new_with_prelude().unwrap();
2450        ts.register_decls(&program.decls).unwrap();
2451        let (preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2452        assert!(preds.is_empty());
2453        assert_eq!(ty, Type::tuple(vec![]));
2454    }
2455
2456    fn strip_span(mut err: TypeError) -> TypeError {
2457        while let TypeError::Spanned { error, .. } = err {
2458            err = *error;
2459        }
2460        err
2461    }
2462
2463    #[test]
2464    fn type_errors_include_span() {
2465        let expr = parse_expr("missing");
2466        let mut ts = TypeSystem::new_with_prelude().unwrap();
2467        let err = infer(&mut ts, expr.as_ref()).unwrap_err();
2468        match err {
2469            TypeError::Spanned { span, error } => {
2470                assert_ne!(span, Span::default());
2471                assert!(matches!(
2472                    *error,
2473                    TypeError::UnknownVar(name) if name.as_ref() == "missing"
2474                ));
2475            }
2476            other => panic!("expected spanned error, got {other:?}"),
2477        }
2478    }
2479
2480    #[test]
2481    fn infer_with_gas_rejects_out_of_budget() {
2482        let expr = parse_expr("1");
2483        let mut ts = TypeSystem::new_with_prelude().unwrap();
2484        let mut gas = GasMeter::new(
2485            Some(0),
2486            GasCosts {
2487                infer_node: 1,
2488                unify_step: 0,
2489                ..GasCosts::sensible_defaults()
2490            },
2491        );
2492        let err = infer_with_gas(&mut ts, expr.as_ref(), &mut gas).unwrap_err();
2493        assert!(matches!(strip_span(err), TypeError::OutOfGas(..)));
2494    }
2495
2496    #[test]
2497    fn reject_user_redefinition_of_primitive_type_name() {
2498        let program = parse_program("type i32 = I32Wrap i32");
2499        let mut ts = TypeSystem::new_with_prelude().unwrap();
2500        let rexlang_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2501            panic!("expected type decl");
2502        };
2503        let err = ts.register_type_decl(decl).unwrap_err();
2504        assert!(matches!(
2505            err,
2506            TypeError::ReservedTypeName(name) if name.as_ref() == "i32"
2507        ));
2508    }
2509
2510    #[test]
2511    fn reject_user_redefinition_of_prelude_adt_name() {
2512        let program = parse_program("type Result e a = Nope e a");
2513        let mut ts = TypeSystem::new_with_prelude().unwrap();
2514        let rexlang_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2515            panic!("expected type decl");
2516        };
2517        let err = ts.register_type_decl(decl).unwrap_err();
2518        assert!(matches!(
2519            err,
2520            TypeError::ReservedTypeName(name) if name.as_ref() == "Result"
2521        ));
2522    }
2523
2524    #[test]
2525    fn infer_polymorphic_id_tuple() {
2526        let expr = parse_expr(
2527            r#"
2528            let
2529                id = \x -> x
2530            in
2531                id (id 420, id 6.9, id "str")
2532            "#,
2533        );
2534        let mut ts = TypeSystem::new_with_prelude().unwrap();
2535        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2536        let expected = Type::tuple(vec![
2537            Type::builtin(BuiltinTypeId::I32),
2538            Type::builtin(BuiltinTypeId::F32),
2539            Type::builtin(BuiltinTypeId::String),
2540        ]);
2541        assert_eq!(ty, expected);
2542    }
2543
2544    #[test]
2545    fn infer_type_annotation_ok() {
2546        let expr = parse_expr("let x: i32 = 42 in x");
2547        let mut ts = TypeSystem::new_with_prelude().unwrap();
2548        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2549        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2550    }
2551
2552    #[test]
2553    fn infer_type_annotation_lambda_param() {
2554        let expr = parse_expr("\\ (a : f32) -> a");
2555        let mut ts = TypeSystem::new_with_prelude().unwrap();
2556        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2557        assert_eq!(
2558            ty,
2559            Type::fun(
2560                Type::builtin(BuiltinTypeId::F32),
2561                Type::builtin(BuiltinTypeId::F32)
2562            )
2563        );
2564    }
2565
2566    #[test]
2567    fn infer_type_annotation_is_alias() {
2568        let expr = parse_expr("\"hi\" is str");
2569        let mut ts = TypeSystem::new_with_prelude().unwrap();
2570        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2571        assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
2572    }
2573
2574    #[test]
2575    fn infer_type_annotation_mismatch_error() {
2576        let expr = parse_expr("let x: i32 = 3.14 in x");
2577        let mut ts = TypeSystem::new_with_prelude().unwrap();
2578        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
2579        assert!(matches!(err, TypeError::Unification(_, _)));
2580    }
2581
2582    #[test]
2583    fn infer_project_single_variant_let() {
2584        let program = parse_program(
2585            r#"
2586            type MyADT = MyVariant1 { field1: i32, field2: f32 }
2587            let
2588                x = MyVariant1 { field1 = 1, field2 = 2.0 }
2589            in
2590                (x.field1, x.field2)
2591            "#,
2592        );
2593        let mut ts = TypeSystem::new_with_prelude().unwrap();
2594        for decl in &program.decls {
2595            if let rexlang_ast::expr::Decl::Type(decl) = decl {
2596                ts.register_type_decl(decl).unwrap();
2597            }
2598        }
2599        let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2600        let expected = Type::tuple(vec![
2601            Type::builtin(BuiltinTypeId::I32),
2602            Type::builtin(BuiltinTypeId::F32),
2603        ]);
2604        assert_eq!(ty, expected);
2605    }
2606
2607    #[test]
2608    fn infer_project_known_variant_let() {
2609        let program = parse_program(
2610            r#"
2611            type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
2612            let
2613                x = MyVariant1 { field1 = 1, field2 = 2.0 }
2614            in
2615                x.field1
2616            "#,
2617        );
2618        let mut ts = TypeSystem::new_with_prelude().unwrap();
2619        for decl in &program.decls {
2620            if let rexlang_ast::expr::Decl::Type(decl) = decl {
2621                ts.register_type_decl(decl).unwrap();
2622            }
2623        }
2624        let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2625        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2626    }
2627
2628    #[test]
2629    fn infer_project_unknown_variant_error() {
2630        let program = parse_program(
2631            r#"
2632            type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
2633            let
2634                x = MyVariant2 1 2.0
2635            in
2636                x.field1
2637            "#,
2638        );
2639        let mut ts = TypeSystem::new_with_prelude().unwrap();
2640        for decl in &program.decls {
2641            if let rexlang_ast::expr::Decl::Type(decl) = decl {
2642                ts.register_type_decl(decl).unwrap();
2643            }
2644        }
2645        let err = strip_span(infer(&mut ts, program.expr.as_ref()).unwrap_err());
2646        assert!(matches!(err, TypeError::FieldNotKnown { .. }));
2647    }
2648
2649    #[test]
2650    fn infer_project_lambda_param_single_variant() {
2651        let program = parse_program(
2652            r#"
2653            type Boxed = Boxed { value: i32 }
2654            let
2655                f = \x -> x.value
2656            in
2657                f (Boxed { value = 1 })
2658            "#,
2659        );
2660        let mut ts = TypeSystem::new_with_prelude().unwrap();
2661        for decl in &program.decls {
2662            if let rexlang_ast::expr::Decl::Type(decl) = decl {
2663                ts.register_type_decl(decl).unwrap();
2664            }
2665        }
2666        let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2667        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2668    }
2669
2670    #[test]
2671    fn infer_project_in_match_arm() {
2672        let program = parse_program(
2673            r#"
2674            type MyADT = MyVariant1 { field1: i32 } | MyVariant2 i32
2675            let
2676                x = MyVariant1 { field1 = 1 }
2677            in
2678                match x
2679                    when MyVariant1 { field1 } -> x.field1
2680                    when MyVariant2 _ -> 0
2681            "#,
2682        );
2683        let mut ts = TypeSystem::new_with_prelude().unwrap();
2684        for decl in &program.decls {
2685            if let rexlang_ast::expr::Decl::Type(decl) = decl {
2686                ts.register_type_decl(decl).unwrap();
2687            }
2688        }
2689        let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2690        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2691    }
2692
2693    #[test]
2694    fn infer_nested_let_lambda_match_option() {
2695        let expr = parse_expr(
2696            r#"
2697            let
2698                choose = \flag a b -> if flag then a else b,
2699                build = \flag ->
2700                    let
2701                        pick = choose flag,
2702                        val = pick 1 2
2703                    in
2704                        Some val
2705            in
2706                match (build true)
2707                    when Some x -> x
2708                    when None -> 0
2709            "#,
2710        );
2711        let mut ts = TypeSystem::new_with_prelude().unwrap();
2712        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2713        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2714    }
2715
2716    #[test]
2717    fn infer_polymorphic_apply_in_tuple() {
2718        let expr = parse_expr(
2719            r#"
2720            let
2721                apply = \f x -> f x,
2722                id = \x -> x,
2723                wrap = \x -> (x, x)
2724            in
2725                (apply id 1, apply id "hi", apply wrap true)
2726            "#,
2727        );
2728        let mut ts = TypeSystem::new_with_prelude().unwrap();
2729        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2730        let expected = Type::tuple(vec![
2731            Type::builtin(BuiltinTypeId::I32),
2732            Type::builtin(BuiltinTypeId::String),
2733            Type::tuple(vec![
2734                Type::builtin(BuiltinTypeId::Bool),
2735                Type::builtin(BuiltinTypeId::Bool),
2736            ]),
2737        ]);
2738        assert_eq!(ty, expected);
2739    }
2740
2741    #[test]
2742    fn infer_nested_result_option_match() {
2743        let expr = parse_expr(
2744            r#"
2745            let
2746                unwrap = \x ->
2747                    match x
2748                        when Ok (Some v) -> v
2749                        when Ok None -> 0
2750                        when Err _ -> 0
2751            in
2752                unwrap (Ok (Some 5))
2753            "#,
2754        );
2755        let mut ts = TypeSystem::new_with_prelude().unwrap();
2756        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2757        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2758    }
2759
2760    #[test]
2761    fn infer_head_or_list_match() {
2762        let expr = parse_expr(
2763            r#"
2764            let
2765                head_or = \fallback xs ->
2766                    match xs
2767                        when [] -> fallback
2768                        when x::xs -> x
2769            in
2770                (head_or 0 [1, 2, 3], head_or 0 [])
2771            "#,
2772        );
2773        let mut ts = TypeSystem::new_with_prelude().unwrap();
2774        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2775        let expected = Type::tuple(vec![
2776            Type::builtin(BuiltinTypeId::I32),
2777            Type::builtin(BuiltinTypeId::I32),
2778        ]);
2779        assert_eq!(ty, expected);
2780    }
2781
2782    #[test]
2783    fn infer_head_or_list_match_cons_constructor_form() {
2784        let expr = parse_expr(
2785            r#"
2786            let
2787                head_or = \fallback xs ->
2788                    match xs
2789                        when [] -> fallback
2790                        when Cons x xs1 -> x
2791            in
2792                (head_or 0 (Cons 1 (Cons 2 Empty)), head_or 0 Empty)
2793            "#,
2794        );
2795        let mut ts = TypeSystem::new_with_prelude().unwrap();
2796        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2797        let expected = Type::tuple(vec![
2798            Type::builtin(BuiltinTypeId::I32),
2799            Type::builtin(BuiltinTypeId::I32),
2800        ]);
2801        assert_eq!(ty, expected);
2802    }
2803
2804    #[test]
2805    fn infer_record_pattern_in_lambda() {
2806        let program = parse_program(
2807            r#"
2808            type Pair = Pair { left: i32, right: i32 }
2809            let
2810                sum = \p ->
2811                    match p
2812                        when Pair { left, right } -> left + right
2813            in
2814                sum (Pair { left = 1, right = 2 })
2815            "#,
2816        );
2817        let mut ts = TypeSystem::new_with_prelude().unwrap();
2818        for decl in &program.decls {
2819            if let rexlang_ast::expr::Decl::Type(decl) = decl {
2820                ts.register_type_decl(decl).unwrap();
2821            }
2822        }
2823        let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2824        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2825    }
2826
2827    #[test]
2828    fn infer_fn_decl_simple() {
2829        let program = parse_program(
2830            r#"
2831            fn add (x: i32, y: i32) -> i32 = x + y
2832            add 1 2
2833            "#,
2834        );
2835        let mut ts = TypeSystem::new_with_prelude().unwrap();
2836        let expr = program.expr_with_fns();
2837        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2838        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2839    }
2840
2841    #[test]
2842    fn infer_fn_decl_signature_form() {
2843        let program = parse_program(
2844            r#"
2845            fn add : i32 -> i32 -> i32 = \x y -> x + y
2846            add 1 2
2847            "#,
2848        );
2849        let mut ts = TypeSystem::new_with_prelude().unwrap();
2850        let expr = program.expr_with_fns();
2851        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2852        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2853    }
2854
2855    #[test]
2856    fn infer_fn_decl_polymorphic_where_constraints() {
2857        let program = parse_program(
2858            r#"
2859            fn my_add (x: a, y: a) -> a where AdditiveMonoid a = x + y
2860            (my_add 1 2, my_add 1.0 2.0)
2861            "#,
2862        );
2863        let mut ts = TypeSystem::new_with_prelude().unwrap();
2864        let expr = program.expr_with_fns();
2865        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2866        assert_eq!(
2867            ty,
2868            Type::tuple(vec![
2869                Type::builtin(BuiltinTypeId::I32),
2870                Type::builtin(BuiltinTypeId::F32)
2871            ])
2872        );
2873    }
2874
2875    #[test]
2876    fn infer_additive_monoid_constraint() {
2877        let expr = parse_expr("\\x y -> x + y");
2878        let mut ts = TypeSystem::new_with_prelude().unwrap();
2879        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2880        assert_eq!(preds.len(), 1);
2881        assert_eq!(preds[0].class.as_ref(), "AdditiveMonoid");
2882
2883        if let TypeKind::Fun(a, rest) = ty.as_ref()
2884            && let TypeKind::Fun(b, c) = rest.as_ref()
2885        {
2886            assert_eq!(a.as_ref(), b.as_ref());
2887            assert_eq!(b.as_ref(), c.as_ref());
2888            assert_eq!(preds[0].typ, a.clone());
2889            return;
2890        }
2891        panic!("expected a -> a -> a");
2892    }
2893
2894    #[test]
2895    fn infer_multiplicative_monoid_constraint() {
2896        let expr = parse_expr("\\x y -> x * y");
2897        let mut ts = TypeSystem::new_with_prelude().unwrap();
2898        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2899        assert_eq!(preds.len(), 1);
2900        assert_eq!(preds[0].class.as_ref(), "MultiplicativeMonoid");
2901
2902        if let TypeKind::Fun(a, rest) = ty.as_ref()
2903            && let TypeKind::Fun(b, c) = rest.as_ref()
2904        {
2905            assert_eq!(a.as_ref(), b.as_ref());
2906            assert_eq!(b.as_ref(), c.as_ref());
2907            assert_eq!(preds[0].typ, a.clone());
2908            return;
2909        }
2910        panic!("expected a -> a -> a");
2911    }
2912
2913    #[test]
2914    fn infer_additive_group_constraint() {
2915        let expr = parse_expr("\\x y -> x - y");
2916        let mut ts = TypeSystem::new_with_prelude().unwrap();
2917        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2918        assert_eq!(preds.len(), 1);
2919        assert_eq!(preds[0].class.as_ref(), "AdditiveGroup");
2920
2921        if let TypeKind::Fun(a, rest) = ty.as_ref()
2922            && let TypeKind::Fun(b, c) = rest.as_ref()
2923        {
2924            assert_eq!(a.as_ref(), b.as_ref());
2925            assert_eq!(b.as_ref(), c.as_ref());
2926            assert_eq!(preds[0].typ, a.clone());
2927            return;
2928        }
2929        panic!("expected a -> a -> a");
2930    }
2931
2932    #[test]
2933    fn infer_integral_constraint() {
2934        let expr = parse_expr("\\x y -> x % y");
2935        let mut ts = TypeSystem::new_with_prelude().unwrap();
2936        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2937        assert_eq!(preds.len(), 1);
2938        assert_eq!(preds[0].class.as_ref(), "Integral");
2939
2940        if let TypeKind::Fun(a, rest) = ty.as_ref()
2941            && let TypeKind::Fun(b, c) = rest.as_ref()
2942        {
2943            assert_eq!(a.as_ref(), b.as_ref());
2944            assert_eq!(b.as_ref(), c.as_ref());
2945            assert_eq!(preds[0].typ, a.clone());
2946            return;
2947        }
2948        panic!("expected a -> a -> a");
2949    }
2950
2951    #[test]
2952    fn infer_literal_addition_defaults() {
2953        let expr = parse_expr("1 + 2");
2954        let mut ts = TypeSystem::new_with_prelude().unwrap();
2955        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2956        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2957        assert_eq!(preds.len(), 2);
2958        assert!(preds.iter().any(|p| p.class.as_ref() == "AdditiveMonoid"));
2959        assert!(preds.iter().any(|p| p.class.as_ref() == "Integral"));
2960        assert!(
2961            preds
2962                .iter()
2963                .all(|p| p.typ == Type::builtin(BuiltinTypeId::I32))
2964        );
2965    }
2966
2967    #[test]
2968    fn infer_mod_defaults() {
2969        let expr = parse_expr("1 % 2");
2970        let mut ts = TypeSystem::new_with_prelude().unwrap();
2971        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2972        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2973        assert_eq!(preds.len(), 1);
2974        assert_eq!(preds[0].class.as_ref(), "Integral");
2975        assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::I32));
2976    }
2977
2978    #[test]
2979    fn infer_get_list_type() {
2980        let expr = parse_expr("get 1 [1, 2, 3]");
2981        let mut ts = TypeSystem::new_with_prelude().unwrap();
2982        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2983        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2984        assert!(preds.iter().any(|p| p.class.as_ref() == "Indexable"));
2985        assert!(preds.iter().all(|p| {
2986            p.class.as_ref() == "Indexable"
2987                || (p.class.as_ref() == "Integral" && p.typ == Type::builtin(BuiltinTypeId::I32))
2988        }));
2989        for pred in preds.iter().filter(|p| p.class.as_ref() == "Indexable") {
2990            assert!(entails(&ts.classes, &[], pred).unwrap());
2991        }
2992    }
2993
2994    #[test]
2995    fn infer_get_tuple_type() {
2996        let expr = parse_expr("(1, 'Hello', true).0");
2997        let mut ts = TypeSystem::new_with_prelude().unwrap();
2998        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2999        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3000        assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3001
3002        let expr = parse_expr("(1, 'Hello', true).1");
3003        let mut ts = TypeSystem::new_with_prelude().unwrap();
3004        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3005        assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
3006        assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3007
3008        let expr = parse_expr("(1, 'Hello', true).2");
3009        let mut ts = TypeSystem::new_with_prelude().unwrap();
3010        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3011        assert_eq!(ty, Type::builtin(BuiltinTypeId::Bool));
3012        assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3013    }
3014
3015    #[test]
3016    fn infer_division_defaults() {
3017        let expr = parse_expr("1.0 / 2.0");
3018        let mut ts = TypeSystem::new_with_prelude().unwrap();
3019        let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3020        assert_eq!(ty, Type::builtin(BuiltinTypeId::F32));
3021        assert_eq!(preds.len(), 1);
3022        assert_eq!(preds[0].class.as_ref(), "Field");
3023        assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::F32));
3024        assert!(entails(&ts.classes, &[], &preds[0]).unwrap());
3025    }
3026
3027    #[test]
3028    fn infer_unbound_variable_error() {
3029        let expr = parse_expr("missing");
3030        let mut ts = TypeSystem::new_with_prelude().unwrap();
3031        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3032        assert!(matches!(
3033            err,
3034            TypeError::UnknownVar(name) if name.as_ref() == "missing"
3035        ));
3036    }
3037
3038    #[test]
3039    fn infer_if_branch_type_mismatch_error() {
3040        let expr = parse_expr(r#"if true then 1 else "no""#);
3041        let mut ts = TypeSystem::new_with_prelude().unwrap();
3042        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3043        match err {
3044            TypeError::Unification(a, b) => {
3045                let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
3046                assert!(ok, "expected i32 vs string, got {a} vs {b}");
3047            }
3048            other => panic!("expected unification error, got {other:?}"),
3049        }
3050    }
3051
3052    #[test]
3053    fn infer_unknown_pattern_constructor_error() {
3054        let expr = parse_expr("match 1 when Nope -> 1");
3055        let mut ts = TypeSystem::new_with_prelude().unwrap();
3056        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3057        assert!(matches!(
3058            err,
3059            TypeError::UnknownVar(name) if name.as_ref() == "Nope"
3060        ));
3061    }
3062
3063    #[test]
3064    fn infer_ambiguous_overload_error() {
3065        let mut ts = TypeSystem::new();
3066        let a = TypeVar::new(0, Some(sym("a")));
3067        let b = TypeVar::new(1, Some(sym("b")));
3068        let scheme_a = Scheme::new(vec![a.clone()], vec![], Type::var(a));
3069        let scheme_b = Scheme::new(vec![b.clone()], vec![], Type::var(b));
3070        ts.add_overload(sym("dup"), scheme_a);
3071        ts.add_overload(sym("dup"), scheme_b);
3072        let expr = parse_expr("dup");
3073        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3074        assert!(matches!(
3075            err,
3076            TypeError::AmbiguousOverload(name) if name.as_ref() == "dup"
3077        ));
3078    }
3079
3080    #[test]
3081    fn infer_if_cond_not_bool_error() {
3082        let expr = parse_expr("if 1 then 2 else 3");
3083        let mut ts = TypeSystem::new_with_prelude().unwrap();
3084        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3085        match err {
3086            TypeError::Unification(a, b) => {
3087                let ok = (a == "bool" && b == "i32") || (a == "i32" && b == "bool");
3088                assert!(ok, "expected bool vs i32, got {a} vs {b}");
3089            }
3090            other => panic!("expected unification error, got {other:?}"),
3091        }
3092    }
3093
3094    #[test]
3095    fn infer_apply_non_function_error() {
3096        let expr = parse_expr("1 2");
3097        let mut ts = TypeSystem::new_with_prelude().unwrap();
3098        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3099        assert!(matches!(err, TypeError::Unification(_, _)));
3100    }
3101
3102    #[test]
3103    fn infer_list_element_mismatch_error() {
3104        let expr = parse_expr("[1, true]");
3105        let mut ts = TypeSystem::new_with_prelude().unwrap();
3106        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3107        match err {
3108            TypeError::Unification(a, b) => {
3109                let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
3110                assert!(ok, "expected i32 vs bool, got {a} vs {b}");
3111            }
3112            other => panic!("expected unification error, got {other:?}"),
3113        }
3114    }
3115
3116    #[test]
3117    fn infer_dict_value_mismatch_error() {
3118        let expr = parse_expr("{a = 1, b = true}");
3119        let mut ts = TypeSystem::new_with_prelude().unwrap();
3120        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3121        match err {
3122            TypeError::Unification(a, b) => {
3123                let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
3124                assert!(ok, "expected i32 vs bool, got {a} vs {b}");
3125            }
3126            other => panic!("expected unification error, got {other:?}"),
3127        }
3128    }
3129
3130    #[test]
3131    fn infer_match_list_on_non_list_error() {
3132        let expr = parse_expr("match 1 when [x] -> x");
3133        let mut ts = TypeSystem::new_with_prelude().unwrap();
3134        assert!(infer(&mut ts, expr.as_ref()).is_err());
3135    }
3136
3137    #[test]
3138    fn infer_pattern_constructor_arity_error() {
3139        let expr = parse_expr("match (Ok 1) when Ok x y -> x");
3140        let mut ts = TypeSystem::new_with_prelude().unwrap();
3141        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3142        assert!(matches!(
3143            err,
3144            TypeError::UnsupportedExpr("pattern constructor")
3145        ));
3146    }
3147
3148    #[test]
3149    fn infer_match_arm_type_mismatch_error() {
3150        let expr = parse_expr(r#"match 1 when _ -> 1 when _ -> "no""#);
3151        let mut ts = TypeSystem::new_with_prelude().unwrap();
3152        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3153        match err {
3154            TypeError::Unification(a, b) => {
3155                let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
3156                assert!(ok, "expected i32 vs string, got {a} vs {b}");
3157            }
3158            other => panic!("expected unification error, got {other:?}"),
3159        }
3160    }
3161
3162    #[test]
3163    fn infer_match_option_on_non_option_error() {
3164        let expr = parse_expr("match 1 when Some x -> x");
3165        let mut ts = TypeSystem::new_with_prelude().unwrap();
3166        assert!(infer(&mut ts, expr.as_ref()).is_err());
3167    }
3168
3169    #[test]
3170    fn infer_dict_pattern_on_non_dict_error() {
3171        let expr = parse_expr("match 1 when {a} -> a");
3172        let mut ts = TypeSystem::new_with_prelude().unwrap();
3173        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3174        assert!(matches!(err, TypeError::Unification(_, _)));
3175    }
3176
3177    #[test]
3178    fn infer_cons_pattern_on_non_list_error() {
3179        let expr = parse_expr("match 1 when x::xs -> x");
3180        let mut ts = TypeSystem::new_with_prelude().unwrap();
3181        assert!(infer(&mut ts, expr.as_ref()).is_err());
3182    }
3183
3184    #[test]
3185    fn infer_apply_wrong_arg_type_error() {
3186        let expr = parse_expr("(\\x -> x + 1) true");
3187        let mut ts = TypeSystem::new_with_prelude().unwrap();
3188        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3189        assert!(matches!(err, TypeError::Unification(_, _)));
3190    }
3191
3192    #[test]
3193    fn infer_self_application_occurs_error() {
3194        let expr = parse_expr("\\x -> x x");
3195        let mut ts = TypeSystem::new_with_prelude().unwrap();
3196        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3197        assert!(matches!(err, TypeError::Occurs(_, _)));
3198    }
3199
3200    #[test]
3201    fn infer_apply_constructor_too_many_args_error() {
3202        let expr = parse_expr("Some 1 2");
3203        let mut ts = TypeSystem::new_with_prelude().unwrap();
3204        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3205        assert!(matches!(err, TypeError::Unification(_, _)));
3206    }
3207
3208    #[test]
3209    fn infer_operator_type_mismatch_error() {
3210        let expr = parse_expr("1 + true");
3211        let mut ts = TypeSystem::new_with_prelude().unwrap();
3212        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3213        assert!(matches!(err, TypeError::Unification(_, _)));
3214    }
3215
3216    #[test]
3217    fn infer_non_exhaustive_match_is_error() {
3218        let expr = parse_expr("match (Ok 1) when Ok x -> x");
3219        let mut ts = TypeSystem::new_with_prelude().unwrap();
3220        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3221        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3222    }
3223
3224    #[test]
3225    fn infer_non_exhaustive_match_on_bound_var_error() {
3226        let expr = parse_expr("let x = Ok 1 in match x when Ok y -> y");
3227        let mut ts = TypeSystem::new_with_prelude().unwrap();
3228        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3229        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3230    }
3231
3232    #[test]
3233    fn infer_non_exhaustive_match_in_lambda_error() {
3234        let expr = parse_expr("\\x -> match x when Ok y -> y");
3235        let mut ts = TypeSystem::new_with_prelude().unwrap();
3236        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3237        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3238    }
3239
3240    #[test]
3241    fn infer_non_exhaustive_option_match_error() {
3242        let expr = parse_expr("match (Some 1) when Some x -> x");
3243        let mut ts = TypeSystem::new_with_prelude().unwrap();
3244        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3245        match err {
3246            TypeError::NonExhaustiveMatch { missing, .. } => {
3247                assert_eq!(missing, vec![sym("None")]);
3248            }
3249            other => panic!("expected non-exhaustive match, got {other:?}"),
3250        }
3251    }
3252
3253    #[test]
3254    fn infer_non_exhaustive_result_match_error() {
3255        let expr = parse_expr("match (Err 1) when Ok x -> x");
3256        let mut ts = TypeSystem::new_with_prelude().unwrap();
3257        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3258        match err {
3259            TypeError::NonExhaustiveMatch { missing, .. } => {
3260                assert_eq!(missing, vec![sym("Err")]);
3261            }
3262            other => panic!("expected non-exhaustive match, got {other:?}"),
3263        }
3264    }
3265
3266    #[test]
3267    fn infer_non_exhaustive_list_missing_empty_error() {
3268        let expr = parse_expr("match [1, 2] when x::xs -> x");
3269        let mut ts = TypeSystem::new_with_prelude().unwrap();
3270        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3271        match err {
3272            TypeError::NonExhaustiveMatch { missing, .. } => {
3273                assert_eq!(missing, vec![sym("Empty")]);
3274            }
3275            other => panic!("expected non-exhaustive match, got {other:?}"),
3276        }
3277    }
3278
3279    #[test]
3280    fn infer_non_exhaustive_list_match_on_bound_var_error() {
3281        let expr = parse_expr("let xs = [1, 2] in match xs when x::xs -> x");
3282        let mut ts = TypeSystem::new_with_prelude().unwrap();
3283        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3284        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3285    }
3286
3287    #[test]
3288    fn infer_non_exhaustive_list_missing_cons_error() {
3289        let expr = parse_expr("match [1] when [] -> 0");
3290        let mut ts = TypeSystem::new_with_prelude().unwrap();
3291        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3292        match err {
3293            TypeError::NonExhaustiveMatch { missing, .. } => {
3294                assert_eq!(missing, vec![sym("Cons")]);
3295            }
3296            other => panic!("expected non-exhaustive match, got {other:?}"),
3297        }
3298    }
3299
3300    #[test]
3301    fn infer_match_list_patterns_on_result_error() {
3302        let expr = parse_expr("match (Ok 1) when [] -> 0 when x::xs -> 1");
3303        let mut ts = TypeSystem::new_with_prelude().unwrap();
3304        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3305        assert!(matches!(err, TypeError::Unification(_, _)));
3306    }
3307
3308    #[test]
3309    fn infer_missing_instances_produce_unsatisfied_predicates() {
3310        for (name, code) in [
3311            ("division", "1 / 2"),
3312            ("eq_dict", "{a = 1} == {a = 2}"),
3313            ("min_bool", "min [true]"),
3314            ("map_dict", r#"map (\x -> x) {a = 1}"#),
3315        ] {
3316            let (class, pred_type, expected_ty) = match name {
3317                "division" => (
3318                    "Field",
3319                    Type::builtin(BuiltinTypeId::I32),
3320                    Some(Type::builtin(BuiltinTypeId::I32)),
3321                ),
3322                "eq_dict" => ("Eq", dict_of(Type::builtin(BuiltinTypeId::I32)), None),
3323                "min_bool" => ("Ord", Type::builtin(BuiltinTypeId::Bool), None),
3324                "map_dict" => ("Functor", Type::builtin(BuiltinTypeId::Dict), None),
3325                _ => unreachable!("unknown test case {name}"),
3326            };
3327
3328            let expr = parse_expr(code);
3329            let mut ts = TypeSystem::new_with_prelude().unwrap();
3330            let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3331            if let Some(expected) = expected_ty {
3332                assert_eq!(ty, expected, "{name}");
3333            }
3334
3335            let pred = preds
3336                .iter()
3337                .find(|p| p.class.as_ref() == class && p.typ == pred_type)
3338                .unwrap();
3339            assert!(!entails(&ts.classes, &[], pred).unwrap(), "{name}");
3340        }
3341    }
3342
3343    #[test]
3344    fn record_update_single_variant_adt_infers() {
3345        let program = parse_program(
3346            r#"
3347            type Foo = Bar { x: i32, y: i32 }
3348            let
3349              foo: Foo = Bar { x = 1, y = 2 },
3350              bar = { foo with { x = 3 } }
3351            in
3352              bar
3353            "#,
3354        );
3355        let mut ts = TypeSystem::new_with_prelude().unwrap();
3356        ts.register_decls(&program.decls).unwrap();
3357        let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3358        assert_eq!(typ.to_string(), "Foo");
3359    }
3360
3361    #[test]
3362    fn record_update_unknown_field_errors() {
3363        let program = parse_program(
3364            r#"
3365            type Foo = Bar { x: i32 }
3366            let
3367              foo: Foo = Bar { x = 1 }
3368            in
3369              { foo with { y = 2 } }
3370            "#,
3371        );
3372        let mut ts = TypeSystem::new_with_prelude().unwrap();
3373        ts.register_decls(&program.decls).unwrap();
3374        let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
3375        let err = strip_span(err);
3376        assert!(matches!(err, TypeError::UnknownField { .. }));
3377    }
3378
3379    #[test]
3380    fn record_update_requires_refined_variant_for_sum_types() {
3381        let program = parse_program(
3382            r#"
3383            type Foo = Bar { x: i32 } | Baz { x: i32 }
3384            let
3385              f = \ (foo : Foo) -> { foo with { x = 2 } }
3386            in
3387              f (Bar { x = 1 })
3388            "#,
3389        );
3390        let mut ts = TypeSystem::new_with_prelude().unwrap();
3391        ts.register_decls(&program.decls).unwrap();
3392        let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
3393        let err = strip_span(err);
3394        assert!(matches!(err, TypeError::FieldNotKnown { .. }));
3395    }
3396
3397    #[test]
3398    fn record_update_allowed_after_match_refines_variant() {
3399        let program = parse_program(
3400            r#"
3401            type Foo = Bar { x: i32 } | Baz { x: i32 }
3402            let
3403              f = \ (foo : Foo) ->
3404                match foo
3405                  when Bar {x} -> { foo with { x = x + 1 } }
3406                  when Baz {x} -> { foo with { x = x + 2 } }
3407            in
3408              f (Bar { x = 1 })
3409            "#,
3410        );
3411        let mut ts = TypeSystem::new_with_prelude().unwrap();
3412        ts.register_decls(&program.decls).unwrap();
3413        let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3414        assert_eq!(typ.to_string(), "Foo");
3415    }
3416
3417    #[test]
3418    fn record_update_plain_record_type() {
3419        let program = parse_program(
3420            r#"
3421            let
3422              f = \ (r : { x: i32, y: i32 }) -> { r with { y = 9 } }
3423            in
3424              f { x = 1, y = 2 }
3425            "#,
3426        );
3427        let mut ts = TypeSystem::new_with_prelude().unwrap();
3428        ts.register_decls(&program.decls).unwrap();
3429        let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3430        assert_eq!(typ.to_string(), "{x: i32, y: i32}");
3431    }
3432
3433    #[test]
3434    fn infer_typed_hole_expr_is_hole_kind() {
3435        let expr = parse_expr("?");
3436        let mut ts = TypeSystem::new_with_prelude().unwrap();
3437        let (typed, _preds, _ty) = infer_typed(&mut ts, expr.as_ref()).unwrap();
3438        assert!(
3439            matches!(typed.kind, TypedExprKind::Hole),
3440            "typed={typed:#?}"
3441        );
3442    }
3443
3444    #[test]
3445    fn infer_hole_with_annotation_unifies_to_annotation() {
3446        let expr = parse_expr("let x : i32 = ? in x");
3447        let mut ts = TypeSystem::new_with_prelude().unwrap();
3448        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3449        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3450    }
3451
3452    #[test]
3453    fn infer_hole_in_if_condition_is_bool_constrained() {
3454        let expr = parse_expr("if ? then 1 else 2");
3455        let mut ts = TypeSystem::new_with_prelude().unwrap();
3456        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3457        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3458    }
3459
3460    #[test]
3461    fn infer_hole_in_arithmetic_is_numeric_constrained() {
3462        let expr = parse_expr("? + 1");
3463        let mut ts = TypeSystem::new_with_prelude().unwrap();
3464        let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3465        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3466    }
3467
3468    #[test]
3469    fn infer_hole_arithmetic_conflicting_annotation_failure() {
3470        let expr = parse_expr("let x : string = (? + 1) in x");
3471        let mut ts = TypeSystem::new_with_prelude().unwrap();
3472        let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3473        assert!(matches!(err, TypeError::Unification(_, _)), "err={err:#?}");
3474    }
3475}