Skip to main content

rex_typesystem/
inference.rs

1use crate::{
2    error::TypeError,
3    types::{
4        AdtDecl, AdtVariant, BuiltinTypeId, Predicate, Scheme, Type, TypeEnv, TypeKind, TypeVar,
5        TypeVarId, TypedExpr, TypedExprKind, Types,
6    },
7    typesystem::{
8        TypeSystem, TypeVarSupply, instantiate, is_integral_literal_expr,
9        predicates_from_constraints, reject_ambiguous_scheme, type_from_annotation_expr,
10        type_from_annotation_expr_vars,
11    },
12    unification::{Subst, Unifier, compose_subst, subst_is_empty, unify},
13};
14use rex_ast::{Expr, Pattern, Symbol, TypeConstraint, TypeExpr};
15use std::{
16    collections::{BTreeMap, BTreeSet},
17    sync::Arc,
18};
19
20fn dedup_preds(preds: Vec<Predicate>) -> Vec<Predicate> {
21    let mut seen = BTreeSet::new();
22    let mut out = Vec::with_capacity(preds.len());
23    for pred in preds {
24        if seen.insert(pred.clone()) {
25            out.push(pred);
26        }
27    }
28    out
29}
30
31fn is_integral_primitive(typ: &Type) -> bool {
32    matches!(
33        typ.as_ref(),
34        TypeKind::Con(tc)
35            if matches!(
36                tc.builtin_id(),
37                Some(
38                    BuiltinTypeId::U8
39                        | BuiltinTypeId::U16
40                        | BuiltinTypeId::U32
41                        | BuiltinTypeId::U64
42                        | BuiltinTypeId::I8
43                        | BuiltinTypeId::I16
44                        | BuiltinTypeId::I32
45                        | BuiltinTypeId::I64
46                )
47            )
48    )
49}
50
51fn finalize_infer_for_public_api(
52    mut preds: Vec<Predicate>,
53    mut typ: Type,
54) -> Result<(Vec<Predicate>, Type), TypeError> {
55    let mut subst = Subst::new_sync();
56    for pred in &preds {
57        if pred.class.as_ref() == "Integral"
58            && let TypeKind::Var(tv) = pred.typ.as_ref()
59        {
60            subst = subst.insert(tv.id, Type::builtin(BuiltinTypeId::I32));
61        }
62    }
63
64    if !subst_is_empty(&subst) {
65        preds = dedup_preds(preds.apply(&subst));
66        typ = typ.apply(&subst);
67    }
68
69    for pred in &preds {
70        if pred.class.as_ref() != "Integral" {
71            continue;
72        }
73        if matches!(pred.typ.as_ref(), TypeKind::Var(_)) || is_integral_primitive(&pred.typ) {
74            continue;
75        }
76        return Err(TypeError::Unification("i32".into(), pred.typ.to_string()));
77    }
78
79    Ok((preds, typ))
80}
81
/// Names one variant of an ADT.
///
/// NOTE(review): the field meanings below are read off the field names; the
/// code that constructs and consumes this struct is outside this view.
#[derive(Clone, Debug)]
struct KnownVariant {
    /// Name of the ADT (presumably the one declaring `variant` — confirm).
    adt: Symbol,
    /// Name of the variant.
    variant: Symbol,
}
87
/// Map from a symbol to the `KnownVariant` recorded for it.
///
/// NOTE(review): keys appear to be variable names — lambda inference removes a
/// parameter's name from this map before inferring the body; confirm.
type KnownVariants = BTreeMap<Symbol, KnownVariant>;
89
90fn apply_scheme_with_unifier(scheme: &Scheme, unifier: &mut Unifier) -> Scheme {
91    let preds = scheme
92        .preds
93        .iter()
94        .map(|pred| Predicate::new(pred.class.clone(), unifier.apply_type(&pred.typ)))
95        .collect();
96    let typ = unifier.apply_type(&scheme.typ);
97    Scheme::new(scheme.vars.clone(), preds, typ)
98}
99
100fn scheme_ftv_with_unifier(scheme: &Scheme, unifier: &mut Unifier) -> BTreeSet<TypeVarId> {
101    let mut ftv = unifier.apply_type(&scheme.typ).ftv();
102    for pred in &scheme.preds {
103        ftv.extend(unifier.apply_type(&pred.typ).ftv());
104    }
105    for var in &scheme.vars {
106        ftv.remove(&var.id);
107    }
108    ftv
109}
110
111fn env_ftv_with_unifier(env: &TypeEnv, unifier: &mut Unifier) -> BTreeSet<TypeVarId> {
112    let mut out = BTreeSet::new();
113    for (_name, schemes) in env.values.iter() {
114        for scheme in schemes {
115            out.extend(scheme_ftv_with_unifier(scheme, unifier));
116        }
117    }
118    out
119}
120
121fn generalize_with_unifier(
122    env: &TypeEnv,
123    preds: Vec<Predicate>,
124    typ: Type,
125    unifier: &mut Unifier,
126) -> Scheme {
127    let preds: Vec<Predicate> = preds
128        .into_iter()
129        .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
130        .collect();
131    let typ = unifier.apply_type(&typ);
132    let mut vars: Vec<TypeVar> = typ
133        .ftv()
134        .union(&preds.ftv())
135        .copied()
136        .collect::<BTreeSet<_>>()
137        .difference(&env_ftv_with_unifier(env, unifier))
138        .cloned()
139        .map(|id| TypeVar::new(id, None))
140        .collect();
141    vars.sort_by_key(|v| v.id);
142    Scheme::new(vars, preds, typ)
143}
144
145fn monomorphic_scheme_with_unifier(
146    preds: Vec<Predicate>,
147    typ: Type,
148    unifier: &mut Unifier,
149) -> Scheme {
150    let preds = dedup_preds(
151        preds
152            .into_iter()
153            .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
154            .collect(),
155    );
156    let typ = unifier.apply_type(&typ);
157    Scheme::new(vec![], preds, typ)
158}
159
/// Infers `expr` and returns the typed expression tree together with the
/// collected predicates and the expression's type.
///
/// Thin public wrapper that forwards directly to `infer_typed_inner`.
///
/// # Errors
///
/// Propagates any `TypeError` produced by `infer_typed_inner`.
pub fn infer_typed(
    type_system: &mut TypeSystem,
    expr: &Expr,
) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
    infer_typed_inner(type_system, expr)
}
166
167fn infer_typed_inner(
168    type_system: &mut TypeSystem,
169    expr: &Expr,
170) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
171    let known = KnownVariants::new();
172    let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
173    let (preds, t, typed) = infer_expr(
174        &mut unifier,
175        &mut type_system.supply,
176        &type_system.env,
177        &type_system.adts,
178        &known,
179        expr,
180    )
181    .map_err(|err| err.with_span(expr.span()))?;
182    let subst = unifier.into_subst();
183    let mut typed = typed.apply(&subst);
184    let mut preds = dedup_preds(preds.apply(&subst));
185    let mut t = t.apply(&subst);
186    let improve = improve_indexable(&preds)?;
187    if !subst_is_empty(&improve) {
188        typed = typed.apply(&improve);
189        preds = dedup_preds(preds.apply(&improve));
190        t = t.apply(&improve);
191    }
192    type_system.check_predicate_kinds(&preds)?;
193    Ok((typed, preds, t))
194}
195
/// Infers `expr` and returns its predicates and type, without a typed tree.
///
/// Thin public wrapper that forwards directly to `infer_inner`.
///
/// # Errors
///
/// Propagates any `TypeError` produced by `infer_inner`.
pub fn infer(
    type_system: &mut TypeSystem,
    expr: &Expr,
) -> Result<(Vec<Predicate>, Type), TypeError> {
    infer_inner(type_system, expr)
}
202
203fn infer_inner(
204    type_system: &mut TypeSystem,
205    expr: &Expr,
206) -> Result<(Vec<Predicate>, Type), TypeError> {
207    let known = KnownVariants::new();
208    let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
209    let (preds, t) = infer_expr_type(
210        &mut unifier,
211        &mut type_system.supply,
212        &type_system.env,
213        &type_system.adts,
214        &known,
215        expr,
216    )
217    .map_err(|err| err.with_span(expr.span()))?;
218    let subst = unifier.into_subst();
219    let mut preds = dedup_preds(preds.apply(&subst));
220    let mut t = t.apply(&subst);
221    let improve = improve_indexable(&preds)?;
222    if !subst_is_empty(&improve) {
223        preds = dedup_preds(preds.apply(&improve));
224        t = t.apply(&improve);
225    }
226    type_system.check_predicate_kinds(&preds)?;
227    finalize_infer_for_public_api(preds, t)
228}
229
230fn improve_indexable(preds: &[Predicate]) -> Result<Subst, TypeError> {
231    let mut subst = Subst::new_sync();
232    loop {
233        let mut changed = false;
234        for pred in preds {
235            let pred = pred.apply(&subst);
236            if pred.class.as_ref() != "Indexable" {
237                continue;
238            }
239            let TypeKind::Tuple(parts) = pred.typ.as_ref() else {
240                continue;
241            };
242            if parts.len() != 2 {
243                continue;
244            }
245            let container = parts[0].clone();
246            let elem = parts[1].clone();
247            let s = indexable_elem_subst(&container, &elem)?;
248            if !subst_is_empty(&s) {
249                subst = compose_subst(s, subst);
250                changed = true;
251            }
252        }
253        if !changed {
254            break;
255        }
256    }
257    Ok(subst)
258}
259
260fn indexable_elem_subst(container: &Type, elem: &Type) -> Result<Subst, TypeError> {
261    match container.as_ref() {
262        TypeKind::App(head, arg) => match head.as_ref() {
263            TypeKind::Con(tc)
264                if matches!(
265                    tc.builtin_id(),
266                    Some(BuiltinTypeId::List | BuiltinTypeId::Array)
267                ) =>
268            {
269                unify(elem, arg)
270            }
271            _ => Ok(Subst::new_sync()),
272        },
273        TypeKind::Tuple(elems) => {
274            if elems.is_empty() {
275                return Ok(Subst::new_sync());
276            }
277            let mut subst = Subst::new_sync();
278            let mut cur = elems[0].clone();
279            for ty in elems.iter().skip(1) {
280                let s_next = unify(&cur.apply(&subst), &ty.apply(&subst))?;
281                subst = compose_subst(s_next, subst);
282                cur = cur.apply(&subst);
283            }
284            let elem = elem.apply(&subst);
285            let s_elem = unify(&elem, &cur.apply(&subst))?;
286            Ok(compose_subst(s_elem, subst))
287        }
288        _ => Ok(Subst::new_sync()),
289    }
290}
291
/// Result of `collect_lambda_chain`: the parameters of the collected lambdas
/// (name plus optional type annotation), the type constraints gathered from
/// them, and the expression left underneath the collected lambdas.
type LambdaChain<'a> = (
    Vec<(Symbol, Option<TypeExpr>)>,
    Vec<TypeConstraint>,
    &'a Expr,
);
297
298fn collect_lambda_chain<'a>(expr: &'a Expr) -> LambdaChain<'a> {
299    let mut params = Vec::new();
300    let mut constraints = Vec::new();
301    let mut cur = expr;
302    let mut seen_constraints = false;
303    while let Expr::Lam(_, _scope, param, ann, lam_constraints, body) = cur {
304        if !lam_constraints.is_empty() {
305            if seen_constraints {
306                break;
307            }
308            constraints = lam_constraints.clone();
309            seen_constraints = true;
310        }
311        params.push((param.name.clone(), ann.clone()));
312        cur = body.as_ref();
313    }
314    (params, constraints, cur)
315}
316
317fn collect_app_chain(expr: &Expr) -> (&Expr, Vec<&Expr>) {
318    let mut args = Vec::new();
319    let mut cur = expr;
320    while let Expr::App(_, f, x) = cur {
321        args.push(x.as_ref());
322        cur = f.as_ref();
323    }
324    args.reverse();
325    (cur, args)
326}
327
328fn narrow_overload_candidates(candidates: &[Type], arg_ty: &Type) -> Vec<Type> {
329    let mut out = Vec::new();
330    for candidate in candidates {
331        let Some((params, ret)) = decompose_fun(candidate, 1) else {
332            continue;
333        };
334        let param = &params[0];
335        if let Ok(s) = unify(param, arg_ty) {
336            out.push(ret.apply(&s));
337        }
338    }
339    out
340}
341
342fn unary_app_arg(typ: &Type, ctor_name: &str) -> Option<Type> {
343    let TypeKind::App(head, arg) = typ.as_ref() else {
344        return None;
345    };
346    let TypeKind::Con(tc) = head.as_ref() else {
347        return None;
348    };
349    (tc.name_str() == ctor_name && tc.arity() == 1).then(|| arg.clone())
350}
351
352fn infer_app_arg_type(
353    unifier: &mut Unifier,
354    supply: &mut TypeVarSupply,
355    env: &TypeEnv,
356    adts: &BTreeMap<Symbol, AdtDecl>,
357    known: &KnownVariants,
358    arg_hint: Option<Type>,
359    arg: &Expr,
360) -> Result<(Vec<Predicate>, Type), TypeError> {
361    match (arg_hint, arg) {
362        (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
363            infer_record_update_type_with_hint(
364                unifier,
365                supply,
366                env,
367                adts,
368                known,
369                base.as_ref(),
370                updates,
371                &arg_hint,
372            )
373        }
374        (Some(arg_hint), Expr::Dict(_, kvs))
375            if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
376        {
377            let TypeKind::Record(fields) = arg_hint.as_ref() else {
378                unreachable!("guarded by matches!")
379            };
380            let expected: BTreeMap<_, _> =
381                fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
382            let mut seen = BTreeSet::new();
383            let mut preds = Vec::new();
384            for (k, v) in kvs {
385                let expected_ty = expected
386                    .get(k)
387                    .ok_or_else(|| TypeError::UnknownField {
388                        field: k.clone(),
389                        typ: Type::record(fields.clone()).to_string(),
390                    })?
391                    .clone();
392                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
393                unifier.unify(&t1, &expected_ty)?;
394                preds.extend(p1);
395                seen.insert(k.clone());
396            }
397            for key in expected.keys() {
398                if !seen.contains(key.as_ref()) {
399                    return Err(TypeError::UnknownField {
400                        field: key.clone(),
401                        typ: Type::record(fields.clone()).to_string(),
402                    });
403                }
404            }
405            let record_ty = Type::record(
406                fields
407                    .iter()
408                    .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
409                    .collect(),
410            );
411            Ok((preds, record_ty))
412        }
413        _ => infer_expr_type(unifier, supply, env, adts, known, arg),
414    }
415}
416
417fn infer_app_arg_typed(
418    unifier: &mut Unifier,
419    supply: &mut TypeVarSupply,
420    env: &TypeEnv,
421    adts: &BTreeMap<Symbol, AdtDecl>,
422    known: &KnownVariants,
423    arg_hint: Option<Type>,
424    arg: &Expr,
425) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
426    match (arg_hint, arg) {
427        (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
428            infer_record_update_typed_with_hint(
429                unifier,
430                supply,
431                env,
432                adts,
433                known,
434                base.as_ref(),
435                updates,
436                &arg_hint,
437            )
438        }
439        (Some(arg_hint), Expr::Dict(_, kvs))
440            if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
441        {
442            let TypeKind::Record(fields) = arg_hint.as_ref() else {
443                unreachable!("guarded by matches!")
444            };
445            let mut preds = Vec::new();
446            let mut typed_kvs = BTreeMap::new();
447            let expected: BTreeMap<_, _> =
448                fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
449            for (k, v) in kvs {
450                let expected_ty = expected
451                    .get(k)
452                    .ok_or_else(|| TypeError::UnknownField {
453                        field: k.clone(),
454                        typ: Type::record(fields.clone()).to_string(),
455                    })?
456                    .clone();
457                let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
458                unifier.unify(&t1, &expected_ty)?;
459                preds.extend(p1);
460                typed_kvs.insert(k.clone(), Arc::new(typed_v));
461            }
462            for key in expected.keys() {
463                if !typed_kvs.contains_key(key.as_ref()) {
464                    return Err(TypeError::UnknownField {
465                        field: key.clone(),
466                        typ: Type::record(fields.clone()).to_string(),
467                    });
468                }
469            }
470            let record_ty = Type::record(
471                fields
472                    .iter()
473                    .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
474                    .collect(),
475            );
476            let typed = TypedExpr::new(record_ty.clone(), TypedExprKind::Dict(typed_kvs));
477            Ok((preds, record_ty, typed))
478        }
479        _ => infer_expr(unifier, supply, env, adts, known, arg),
480    }
481}
482
/// Running state while the arguments of an application are applied one at a
/// time to an already-inferred head.
struct TypedAppState {
    /// Predicates collected from the head and from every argument applied so far.
    preds: Vec<Predicate>,
    /// Type of the expression built so far; replaced by the result type after
    /// each applied argument.
    func_ty: Type,
    /// Typed expression built so far; wrapped in a new `App` node per argument.
    typed: TypedExpr,
    /// Name of the head variable, set only when the head is a variable with a
    /// non-empty overload list.
    overload_name: Option<Symbol>,
    /// Overload candidate types still under consideration; narrowed against
    /// each applied argument.
    overload_candidates: Option<Vec<Type>>,
}
490
/// One application level recorded by `collect_tail_app_chain`: an application
/// whose last argument is itself an application.
struct TailAppFrame<'a> {
    /// Head expression of this application.
    head: &'a Expr,
    /// All arguments of this application except the last one.
    prefix_args: Vec<&'a Expr>,
}
495
496fn app_arg_hint(unifier: &mut Unifier, func_ty: &Type) -> Option<Type> {
497    match unifier.apply_type(func_ty).as_ref() {
498        TypeKind::Fun(arg, _) => Some(arg.clone()),
499        _ => None,
500    }
501}
502
503#[allow(clippy::too_many_arguments)]
504fn infer_typed_app_head(
505    unifier: &mut Unifier,
506    supply: &mut TypeVarSupply,
507    env: &TypeEnv,
508    adts: &BTreeMap<Symbol, AdtDecl>,
509    known: &KnownVariants,
510    head: &Expr,
511) -> Result<TypedAppState, TypeError> {
512    let (preds, func_ty, typed) = infer_expr(unifier, supply, env, adts, known, head)?;
513    let mut overload_name = None;
514    let overload_candidates = match typed.kind.as_ref() {
515        TypedExprKind::Var { name, overloads } if !overloads.is_empty() => {
516            overload_name = Some(name.clone());
517            Some(overloads.clone())
518        }
519        _ => None,
520    };
521    Ok(TypedAppState {
522        preds,
523        func_ty,
524        typed,
525        overload_name,
526        overload_candidates,
527    })
528}
529
530fn apply_typed_app_arg(
531    unifier: &mut Unifier,
532    supply: &mut TypeVarSupply,
533    state: &mut TypedAppState,
534    expected_arg: Option<Type>,
535    p_arg: Vec<Predicate>,
536    arg_ty: Type,
537    typed_arg: TypedExpr,
538) -> Result<(), TypeError> {
539    let mut arg_ty = unifier.apply_type(&arg_ty);
540    let mut typed_arg = typed_arg;
541
542    if let Some(expected_arg) = expected_arg {
543        let expected_arg = unifier.apply_type(&expected_arg);
544        if let (Some(expected_elem), Some(arg_elem)) = (
545            unary_app_arg(&expected_arg, "Array"),
546            unary_app_arg(&arg_ty, "List"),
547        ) {
548            unifier.unify(&expected_elem, &arg_elem)?;
549            let elem_ty = unifier.apply_type(&expected_elem);
550            let list_ty = Type::list(elem_ty.clone());
551            let array_ty = Type::array(elem_ty);
552            let coercion_ty = Type::fun(list_ty, array_ty.clone());
553            let coercion_fn = TypedExpr::new(
554                coercion_ty,
555                TypedExprKind::Var {
556                    name: Symbol::intern("prim_array_from_list"),
557                    overloads: vec![],
558                },
559            );
560            typed_arg = TypedExpr::new(
561                array_ty.clone(),
562                TypedExprKind::App(Arc::new(coercion_fn), Arc::new(typed_arg)),
563            );
564            arg_ty = array_ty;
565        }
566    }
567    if let Some(candidates) = state.overload_candidates.take() {
568        let candidates = candidates
569            .into_iter()
570            .map(|t| unifier.apply_type(&t))
571            .collect::<Vec<_>>();
572        let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
573        if narrowed.is_empty()
574            && let Some(name) = &state.overload_name
575        {
576            return Err(TypeError::AmbiguousOverload(name.clone()));
577        }
578        state.overload_candidates = Some(narrowed);
579    }
580    let res_ty = match state.overload_candidates.as_ref() {
581        Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
582        _ => Type::var(supply.fresh(Some(Symbol::intern("r")))),
583    };
584    unifier.unify(&state.func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
585    let result_ty = match state.overload_candidates.as_ref() {
586        Some(candidates) if candidates.len() == 1 => unifier.apply_type(&candidates[0]),
587        _ => unifier.apply_type(&res_ty),
588    };
589    state.preds.extend(p_arg);
590    state.typed = TypedExpr::new(
591        result_ty.clone(),
592        TypedExprKind::App(Arc::new(state.typed.clone()), Arc::new(typed_arg)),
593    );
594    state.func_ty = result_ty;
595    Ok(())
596}
597
598#[allow(clippy::too_many_arguments)]
599fn infer_typed_app_expr_arg(
600    unifier: &mut Unifier,
601    supply: &mut TypeVarSupply,
602    env: &TypeEnv,
603    adts: &BTreeMap<Symbol, AdtDecl>,
604    known: &KnownVariants,
605    state: &mut TypedAppState,
606    arg: &Expr,
607) -> Result<(), TypeError> {
608    let expected_arg = app_arg_hint(unifier, &state.func_ty);
609    let (p_arg, arg_ty, typed_arg) =
610        infer_app_arg_typed(unifier, supply, env, adts, known, expected_arg.clone(), arg)?;
611    apply_typed_app_arg(
612        unifier,
613        supply,
614        state,
615        expected_arg,
616        p_arg,
617        arg_ty,
618        typed_arg,
619    )
620}
621
622fn collect_tail_app_chain(expr: &Expr) -> Option<(&Expr, Vec<TailAppFrame<'_>>)> {
623    let mut frames = Vec::new();
624    let mut cur = expr;
625    while let Expr::App(..) = cur {
626        let (head, mut args) = collect_app_chain(cur);
627        let Some(tail) = args.pop() else {
628            break;
629        };
630        if !matches!(tail, Expr::App(..)) {
631            break;
632        }
633        frames.push(TailAppFrame {
634            head,
635            prefix_args: args,
636        });
637        cur = tail;
638    }
639    (!frames.is_empty()).then_some((cur, frames))
640}
641
642#[allow(clippy::too_many_arguments)]
643fn infer_tail_app_chain_typed(
644    unifier: &mut Unifier,
645    supply: &mut TypeVarSupply,
646    env: &TypeEnv,
647    adts: &BTreeMap<Symbol, AdtDecl>,
648    known: &KnownVariants,
649    leaf: &Expr,
650    frames: Vec<TailAppFrame<'_>>,
651) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
652    let (mut preds, mut tail_ty, mut typed_tail) =
653        infer_expr(unifier, supply, env, adts, known, leaf)?;
654
655    for frame in frames.into_iter().rev() {
656        let mut state = infer_typed_app_head(unifier, supply, env, adts, known, frame.head)?;
657        for arg in frame.prefix_args {
658            infer_typed_app_expr_arg(unifier, supply, env, adts, known, &mut state, arg)?;
659        }
660        let expected_arg = app_arg_hint(unifier, &state.func_ty);
661        apply_typed_app_arg(
662            unifier,
663            supply,
664            &mut state,
665            expected_arg,
666            Vec::new(),
667            tail_ty,
668            typed_tail,
669        )?;
670        preds.extend(state.preds);
671        tail_ty = state.func_ty;
672        typed_tail = state.typed;
673    }
674
675    Ok((preds, tail_ty, typed_tail))
676}
677
678#[allow(clippy::too_many_arguments)]
679fn infer_record_update_type_with_hint(
680    unifier: &mut Unifier,
681    supply: &mut TypeVarSupply,
682    env: &TypeEnv,
683    adts: &BTreeMap<Symbol, AdtDecl>,
684    known: &KnownVariants,
685    base: &Expr,
686    updates: &BTreeMap<Symbol, Arc<Expr>>,
687    hint_ty: &Type,
688) -> Result<(Vec<Predicate>, Type), TypeError> {
689    let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
690    unifier.unify(&t_base, hint_ty)?;
691    let base_ty = unifier.apply_type(&t_base);
692    let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
693    let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
694    let (result_ty, fields) = resolve_record_update(
695        unifier,
696        supply,
697        adts,
698        &base_ty,
699        known_variant,
700        &update_fields,
701    )?;
702    let expected: BTreeMap<_, _> = fields.into_iter().collect();
703
704    let mut preds = p_base;
705    for (k, v) in updates {
706        let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
707            field: k.clone(),
708            typ: result_ty.to_string(),
709        })?;
710        let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
711        unifier.unify(&t1, expected_ty)?;
712        preds.extend(p1);
713    }
714    Ok((preds, result_ty))
715}
716
717#[allow(clippy::too_many_arguments)]
718fn infer_record_update_typed_with_hint(
719    unifier: &mut Unifier,
720    supply: &mut TypeVarSupply,
721    env: &TypeEnv,
722    adts: &BTreeMap<Symbol, AdtDecl>,
723    known: &KnownVariants,
724    base: &Expr,
725    updates: &BTreeMap<Symbol, Arc<Expr>>,
726    hint_ty: &Type,
727) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
728    let (p_base, t_base, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
729    unifier.unify(&t_base, hint_ty)?;
730    let base_ty = unifier.apply_type(&t_base);
731    let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
732    let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
733    let (result_ty, fields) = resolve_record_update(
734        unifier,
735        supply,
736        adts,
737        &base_ty,
738        known_variant,
739        &update_fields,
740    )?;
741    let expected: BTreeMap<_, _> = fields.into_iter().collect();
742
743    let mut preds = p_base;
744    let mut typed_updates = BTreeMap::new();
745    for (k, v) in updates {
746        let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
747            field: k.clone(),
748            typ: result_ty.to_string(),
749        })?;
750        let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
751        unifier.unify(&t1, expected_ty)?;
752        preds.extend(p1);
753        typed_updates.insert(k.clone(), Arc::new(typed_v));
754    }
755
756    let typed = TypedExpr::new(
757        result_ty.clone(),
758        TypedExprKind::RecordUpdate {
759            base: Arc::new(typed_base),
760            updates: typed_updates,
761        },
762    );
763    Ok((preds, result_ty, typed))
764}
765
766fn infer_expr_type(
767    unifier: &mut Unifier,
768    supply: &mut TypeVarSupply,
769    env: &TypeEnv,
770    adts: &BTreeMap<Symbol, AdtDecl>,
771    known: &KnownVariants,
772    expr: &Expr,
773) -> Result<(Vec<Predicate>, Type), TypeError> {
774    let span = *expr.span();
775    let res = unifier.with_infer_depth(span, |unifier| {
776        infer_expr_type_inner(unifier, supply, env, adts, known, expr)
777    });
778    res.map_err(|err| err.with_span(&span))
779}
780
781fn infer_expr_type_inner(
782    unifier: &mut Unifier,
783    supply: &mut TypeVarSupply,
784    env: &TypeEnv,
785    adts: &BTreeMap<Symbol, AdtDecl>,
786    known: &KnownVariants,
787    expr: &Expr,
788) -> Result<(Vec<Predicate>, Type), TypeError> {
789    match expr {
790        Expr::Bool(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Bool))),
791        Expr::Uint(_, _) => {
792            let lit_ty = Type::var(supply.fresh(Some(Symbol::intern("n"))));
793            Ok((vec![Predicate::new("Integral", lit_ty.clone())], lit_ty))
794        }
795        Expr::Int(_, _) => {
796            let lit_ty = Type::var(supply.fresh(Some(Symbol::intern("n"))));
797            Ok((
798                vec![
799                    Predicate::new("Integral", lit_ty.clone()),
800                    Predicate::new("AdditiveGroup", lit_ty.clone()),
801                ],
802                lit_ty,
803            ))
804        }
805        Expr::Float(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::F32))),
806        Expr::String(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::String))),
807        Expr::Uuid(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Uuid))),
808        Expr::DateTime(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::DateTime))),
809        Expr::Hole(_) => {
810            let t = Type::var(supply.fresh(Some(Symbol::intern("hole"))));
811            Ok((vec![], t))
812        }
813        Expr::Var(var) => {
814            let schemes = env
815                .lookup(&var.name)
816                .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
817            if schemes.len() == 1 {
818                let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
819                let (preds, t) = instantiate(&scheme, supply);
820                Ok((preds, t))
821            } else {
822                for scheme in schemes {
823                    if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
824                        return Err(TypeError::AmbiguousOverload(var.name.clone()));
825                    }
826                }
827                let t = Type::var(supply.fresh(Some(var.name.clone())));
828                Ok((vec![], t))
829            }
830        }
831        Expr::Lam(..) => {
832            let (params, constraints, body) = collect_lambda_chain(expr);
833            let mut ann_vars = BTreeMap::new();
834            let mut param_tys = Vec::with_capacity(params.len());
835            for (name, ann) in &params {
836                let param_ty = match ann {
837                    Some(ann) => type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?,
838                    None => Type::var(supply.fresh(Some(name.clone()))),
839                };
840                param_tys.push((name.clone(), param_ty));
841            }
842
843            let mut env1 = env.clone();
844            let mut known_body = known.clone();
845            for (name, param_ty) in &param_tys {
846                env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
847                known_body.remove(name);
848            }
849
850            let (mut preds, body_ty) =
851                infer_expr_type(unifier, supply, &env1, adts, &known_body, body)?;
852            let constraint_preds =
853                predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
854            preds.extend(constraint_preds);
855
856            let mut fun_ty = unifier.apply_type(&body_ty);
857            for (_, param_ty) in param_tys.iter().rev() {
858                fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
859            }
860            Ok((preds, fun_ty))
861        }
862        Expr::App(..) => {
863            let (head, args) = collect_app_chain(expr);
864            let (mut preds, mut func_ty) =
865                infer_expr_type(unifier, supply, env, adts, known, head)?;
866            let mut overload_name = None;
867            let mut overload_candidates = if let Expr::Var(var) = head {
868                if let Some(schemes) = env.lookup(&var.name) {
869                    if schemes.len() <= 1 {
870                        None
871                    } else {
872                        let mut candidates = Vec::new();
873                        for scheme in schemes {
874                            if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
875                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
876                            }
877                            let scheme = apply_scheme_with_unifier(scheme, unifier);
878                            let (p, typ) = instantiate(&scheme, supply);
879                            if !p.is_empty() {
880                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
881                            }
882                            candidates.push(typ);
883                        }
884                        overload_name = Some(var.name.clone());
885                        Some(candidates)
886                    }
887                } else {
888                    None
889                }
890            } else {
891                None
892            };
893            for arg in args {
894                let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
895                    TypeKind::Fun(arg, _) => Some(arg.clone()),
896                    _ => None,
897                };
898                let (p_arg, arg_ty) =
899                    infer_app_arg_type(unifier, supply, env, adts, known, arg_hint, arg)?;
900                let arg_ty = unifier.apply_type(&arg_ty);
901                if let Some(candidates) = overload_candidates.take() {
902                    let candidates = candidates
903                        .into_iter()
904                        .map(|t| unifier.apply_type(&t))
905                        .collect::<Vec<_>>();
906                    let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
907                    if narrowed.is_empty()
908                        && let Some(name) = &overload_name
909                    {
910                        return Err(TypeError::AmbiguousOverload(name.clone()));
911                    }
912                    overload_candidates = Some(narrowed);
913                }
914                let res_ty = match overload_candidates.as_ref() {
915                    Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
916                    _ => Type::var(supply.fresh(Some(Symbol::intern("r")))),
917                };
918                unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
919                preds.extend(p_arg);
920                func_ty = match overload_candidates.as_ref() {
921                    Some(candidates) if candidates.len() == 1 => unifier.apply_type(&candidates[0]),
922                    _ => unifier.apply_type(&res_ty),
923                };
924            }
925            Ok((preds, func_ty))
926        }
927        Expr::Project(_, base, field) => {
928            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, base)?;
929            let base_ty = unifier.apply_type(&t1);
930            let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
931            let field_ty =
932                resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
933            Ok((p1, field_ty))
934        }
935        Expr::RecordUpdate(_, base, updates) => {
936            let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
937            let base_ty = unifier.apply_type(&t_base);
938            let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
939            let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
940            let (result_ty, fields) = resolve_record_update(
941                unifier,
942                supply,
943                adts,
944                &base_ty,
945                known_variant,
946                &update_fields,
947            )?;
948            let expected: BTreeMap<_, _> = fields.into_iter().collect();
949
950            let mut preds = p_base;
951            for (k, v) in updates {
952                let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
953                    field: k.clone(),
954                    typ: result_ty.to_string(),
955                })?;
956                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
957                unifier.unify(&t1, expected_ty)?;
958                preds.extend(p1);
959            }
960            Ok((preds, result_ty))
961        }
962        Expr::Let(..) => {
963            let mut bindings = Vec::new();
964            let mut cur = expr;
965            while let Expr::Let(_, v, ann, d, b) = cur {
966                bindings.push((v.clone(), ann.clone(), d.clone()));
967                cur = b.as_ref();
968            }
969
970            let mut env_cur = env.clone();
971            let mut known_cur = known.clone();
972            for (v, ann, d) in bindings {
973                let (p1, t1) = if let Some(ref ann_expr) = ann {
974                    let mut ann_vars = BTreeMap::new();
975                    let ann_ty =
976                        type_from_annotation_expr_vars(adts, ann_expr, &mut ann_vars, supply)?;
977                    match d.as_ref() {
978                        Expr::RecordUpdate(_, base, updates) => infer_record_update_type_with_hint(
979                            unifier,
980                            supply,
981                            &env_cur,
982                            adts,
983                            &known_cur,
984                            base.as_ref(),
985                            updates,
986                            &ann_ty,
987                        )?,
988                        _ => {
989                            let (p1, t1) =
990                                infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?;
991                            unifier.unify(&t1, &ann_ty)?;
992                            (p1, t1)
993                        }
994                    }
995                } else {
996                    infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?
997                };
998                let def_ty = unifier.apply_type(&t1);
999                let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
1000                    monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
1001                } else {
1002                    let scheme = generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
1003                    reject_ambiguous_scheme(&scheme)?;
1004                    scheme
1005                };
1006                env_cur.extend(v.name.clone(), scheme);
1007                if let Some(known_variant) =
1008                    known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
1009                {
1010                    known_cur.insert(
1011                        v.name.clone(),
1012                        KnownVariant {
1013                            adt: known_variant.adt,
1014                            variant: known_variant.variant,
1015                        },
1016                    );
1017                } else {
1018                    known_cur.remove(&v.name);
1019                }
1020            }
1021
1022            let (p_body, t_body) =
1023                infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, cur)?;
1024            Ok((p_body, t_body))
1025        }
1026        Expr::LetRec(_, bindings, body) => {
1027            let mut env_seed = env.clone();
1028            let mut known_seed = known.clone();
1029            let mut binding_tys = BTreeMap::new();
1030            for (var, _ann, _def) in bindings {
1031                let tv = Type::var(supply.fresh(Some(var.name.clone())));
1032                env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
1033                known_seed.remove(&var.name);
1034                binding_tys.insert(var.name.clone(), tv);
1035            }
1036
1037            let mut inferred = Vec::with_capacity(bindings.len());
1038            for (var, ann, def) in bindings {
1039                let (preds, def_ty) =
1040                    infer_expr_type(unifier, supply, &env_seed, adts, &known_seed, def)?;
1041                if let Some(ann) = ann {
1042                    let mut ann_vars = BTreeMap::new();
1043                    let ann_ty = type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
1044                    unifier.unify(&def_ty, &ann_ty)?;
1045                }
1046                let binding_ty = binding_tys
1047                    .get(&var.name)
1048                    .cloned()
1049                    .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1050                unifier.unify(&binding_ty, &def_ty)?;
1051                let resolved_ty = unifier.apply_type(&binding_ty);
1052
1053                if let Some(known_variant) =
1054                    known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
1055                {
1056                    known_seed.insert(
1057                        var.name.clone(),
1058                        KnownVariant {
1059                            adt: known_variant.adt,
1060                            variant: known_variant.variant,
1061                        },
1062                    );
1063                } else {
1064                    known_seed.remove(&var.name);
1065                }
1066                inferred.push((var.name.clone(), preds, resolved_ty));
1067            }
1068
1069            let mut env_body = env.clone();
1070            for (name, preds, def_ty) in inferred {
1071                let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
1072                reject_ambiguous_scheme(&scheme)?;
1073                env_body.extend(name, scheme);
1074            }
1075
1076            let (p_body, t_body) =
1077                infer_expr_type(unifier, supply, &env_body, adts, &known_seed, body)?;
1078            Ok((p_body, t_body))
1079        }
1080        Expr::Ite(_, cond, then_expr, else_expr) => {
1081            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, cond)?;
1082            unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
1083            let (p2, t2) = infer_expr_type(unifier, supply, env, adts, known, then_expr)?;
1084            let (p3, t3) = infer_expr_type(unifier, supply, env, adts, known, else_expr)?;
1085            unifier.unify(&t2, &t3)?;
1086            let out_ty = unifier.apply_type(&t2);
1087            let mut preds = p1;
1088            preds.extend(p2);
1089            preds.extend(p3);
1090            Ok((preds, out_ty))
1091        }
1092        Expr::Tuple(_, elems) => {
1093            let mut preds = Vec::new();
1094            let mut types = Vec::new();
1095            for elem in elems {
1096                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
1097                preds.extend(p1);
1098                types.push(unifier.apply_type(&t1));
1099            }
1100            let tuple_ty = Type::tuple(types);
1101            Ok((preds, tuple_ty))
1102        }
1103        Expr::List(_, elems) => {
1104            let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("a"))));
1105            let mut preds = Vec::new();
1106            for elem in elems {
1107                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
1108                unifier.unify(&t1, &elem_tv)?;
1109                preds.extend(p1);
1110            }
1111            let list_ty = Type::app(
1112                Type::builtin(BuiltinTypeId::List),
1113                unifier.apply_type(&elem_tv),
1114            );
1115            Ok((preds, list_ty))
1116        }
1117        Expr::Dict(_, kvs) => {
1118            let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("v"))));
1119            let mut preds = Vec::new();
1120            for v in kvs.values() {
1121                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
1122                unifier.unify(&t1, &elem_tv)?;
1123                preds.extend(p1);
1124            }
1125            let dict_ty = Type::app(
1126                Type::builtin(BuiltinTypeId::Dict),
1127                unifier.apply_type(&elem_tv),
1128            );
1129            Ok((preds, dict_ty))
1130        }
1131        Expr::Match(_, scrutinee, arms) => {
1132            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, scrutinee.as_ref())?;
1133            let mut preds = p1;
1134            let res_ty = Type::var(supply.fresh(Some(Symbol::intern("match"))));
1135            let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
1136
1137            for (pat, expr) in arms {
1138                let scrutinee_ty = unifier.apply_type(&t1);
1139                let (p_pat, binds) = infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
1140                preds.extend(p_pat);
1141
1142                let mut env_arm = env.clone();
1143                for (name, ty) in binds {
1144                    env_arm.extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
1145                }
1146                let mut known_arm = known.clone();
1147                if let Expr::Var(var) = scrutinee.as_ref() {
1148                    match pat {
1149                        Pattern::Named(_, name, _) => {
1150                            let name_sym = name.to_dotted_symbol();
1151                            if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
1152                                known_arm.insert(
1153                                    var.name.clone(),
1154                                    KnownVariant {
1155                                        adt: adt.name.clone(),
1156                                        variant: name_sym,
1157                                    },
1158                                );
1159                            } else {
1160                                known_arm.remove(&var.name);
1161                            }
1162                        }
1163                        _ => {
1164                            known_arm.remove(&var.name);
1165                        }
1166                    }
1167                }
1168                let (p_expr, t_expr) =
1169                    infer_expr_type(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1170                unifier.unify(&res_ty, &t_expr)?;
1171                preds.extend(p_expr);
1172            }
1173
1174            let scrutinee_ty = unifier.apply_type(&t1);
1175            check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1176            let out_ty = unifier.apply_type(&res_ty);
1177            Ok((preds, out_ty))
1178        }
1179        Expr::Ann(_, expr, ann) => {
1180            let ann_ty = type_from_annotation_expr(adts, ann)?;
1181            match expr.as_ref() {
1182                Expr::RecordUpdate(_, base, updates) => {
1183                    let (preds, out_ty) = infer_record_update_type_with_hint(
1184                        unifier,
1185                        supply,
1186                        env,
1187                        adts,
1188                        known,
1189                        base.as_ref(),
1190                        updates,
1191                        &ann_ty,
1192                    )?;
1193                    Ok((preds, out_ty))
1194                }
1195                _ => {
1196                    let (preds, expr_ty) =
1197                        infer_expr_type(unifier, supply, env, adts, known, expr)?;
1198                    unifier.unify(&expr_ty, &ann_ty)?;
1199                    let out_ty = unifier.apply_type(&ann_ty);
1200                    Ok((preds, out_ty))
1201                }
1202            }
1203        }
1204    }
1205}
1206
1207fn infer_expr(
1208    unifier: &mut Unifier,
1209    supply: &mut TypeVarSupply,
1210    env: &TypeEnv,
1211    adts: &BTreeMap<Symbol, AdtDecl>,
1212    known: &KnownVariants,
1213    expr: &Expr,
1214) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
1215    let span = *expr.span();
1216    let res = unifier.with_infer_depth(span, |unifier| {
1217        (|| match expr {
1218            Expr::Bool(_, v) => {
1219                let t = Type::builtin(BuiltinTypeId::Bool);
1220                Ok((
1221                    vec![],
1222                    t.clone(),
1223                    TypedExpr::new(t, TypedExprKind::Bool(*v)),
1224                ))
1225            }
1226            Expr::Uint(_, v) => {
1227                let t = Type::var(supply.fresh(Some(Symbol::intern("n"))));
1228                Ok((
1229                    vec![Predicate::new("Integral", t.clone())],
1230                    t.clone(),
1231                    TypedExpr::new(t, TypedExprKind::Uint(*v)),
1232                ))
1233            }
1234            Expr::Int(_, v) => {
1235                let t = Type::var(supply.fresh(Some(Symbol::intern("n"))));
1236                Ok((
1237                    vec![
1238                        Predicate::new("Integral", t.clone()),
1239                        Predicate::new("AdditiveGroup", t.clone()),
1240                    ],
1241                    t.clone(),
1242                    TypedExpr::new(t, TypedExprKind::Int(*v)),
1243                ))
1244            }
1245            Expr::Float(_, v) => {
1246                let t = Type::builtin(BuiltinTypeId::F32);
1247                Ok((
1248                    vec![],
1249                    t.clone(),
1250                    TypedExpr::new(t, TypedExprKind::Float(*v)),
1251                ))
1252            }
1253            Expr::String(_, v) => {
1254                let t = Type::builtin(BuiltinTypeId::String);
1255                Ok((
1256                    vec![],
1257                    t.clone(),
1258                    TypedExpr::new(t, TypedExprKind::String(v.clone())),
1259                ))
1260            }
1261            Expr::Uuid(_, v) => {
1262                let t = Type::builtin(BuiltinTypeId::Uuid);
1263                Ok((
1264                    vec![],
1265                    t.clone(),
1266                    TypedExpr::new(t, TypedExprKind::Uuid(*v)),
1267                ))
1268            }
1269            Expr::DateTime(_, v) => {
1270                let t = Type::builtin(BuiltinTypeId::DateTime);
1271                Ok((
1272                    vec![],
1273                    t.clone(),
1274                    TypedExpr::new(t, TypedExprKind::DateTime(*v)),
1275                ))
1276            }
1277            Expr::Hole(_) => {
1278                let t = Type::var(supply.fresh(Some(Symbol::intern("hole"))));
1279                Ok((vec![], t.clone(), TypedExpr::new(t, TypedExprKind::Hole)))
1280            }
1281            Expr::Var(var) => {
1282                let schemes = env
1283                    .lookup(&var.name)
1284                    .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1285                if schemes.len() == 1 {
1286                    let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
1287                    let (preds, t) = instantiate(&scheme, supply);
1288                    let typed = TypedExpr::new(
1289                        t.clone(),
1290                        TypedExprKind::Var {
1291                            name: var.name.clone(),
1292                            overloads: vec![],
1293                        },
1294                    );
1295                    Ok((preds, t, typed))
1296                } else {
1297                    let mut overloads = Vec::new();
1298                    for scheme in schemes {
1299                        if !scheme.preds.is_empty() {
1300                            return Err(TypeError::AmbiguousOverload(var.name.clone()));
1301                        }
1302
1303                        let scheme = apply_scheme_with_unifier(scheme, unifier);
1304                        let (preds, typ) = instantiate(&scheme, supply);
1305                        if !preds.is_empty() {
1306                            return Err(TypeError::AmbiguousOverload(var.name.clone()));
1307                        }
1308                        overloads.push(typ);
1309                    }
1310                    let t = Type::var(supply.fresh(Some(var.name.clone())));
1311                    let typed = TypedExpr::new(
1312                        t.clone(),
1313                        TypedExprKind::Var {
1314                            name: var.name.clone(),
1315                            overloads,
1316                        },
1317                    );
1318                    Ok((vec![], t, typed))
1319                }
1320            }
1321            Expr::Lam(..) => {
1322                let (params, constraints, body) = collect_lambda_chain(expr);
1323                let mut ann_vars = BTreeMap::new();
1324                let mut param_tys = Vec::with_capacity(params.len());
1325                for (name, ann) in &params {
1326                    let param_ty = match ann {
1327                        Some(ann) => {
1328                            type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?
1329                        }
1330                        None => Type::var(supply.fresh(Some(name.clone()))),
1331                    };
1332                    param_tys.push((name.clone(), param_ty));
1333                }
1334
1335                let mut env1 = env.clone();
1336                let mut known_body = known.clone();
1337                for (name, param_ty) in &param_tys {
1338                    env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
1339                    known_body.remove(name);
1340                }
1341
1342                let (mut preds, body_ty, typed_body) =
1343                    infer_expr(unifier, supply, &env1, adts, &known_body, body)?;
1344                let constraint_preds =
1345                    predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
1346                preds.extend(constraint_preds);
1347
1348                let mut typed = typed_body;
1349                let mut fun_ty = unifier.apply_type(&body_ty);
1350                for (name, param_ty) in param_tys.iter().rev() {
1351                    fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
1352                    typed = TypedExpr::new(
1353                        fun_ty.clone(),
1354                        TypedExprKind::Lam {
1355                            param: name.clone(),
1356                            body: Arc::new(typed),
1357                        },
1358                    );
1359                }
1360
1361                Ok((preds, fun_ty, typed))
1362            }
1363            Expr::App(..) => {
1364                if let Some((leaf, frames)) = collect_tail_app_chain(expr) {
1365                    return infer_tail_app_chain_typed(
1366                        unifier, supply, env, adts, known, leaf, frames,
1367                    );
1368                }
1369                let (head, args) = collect_app_chain(expr);
1370                let mut state = infer_typed_app_head(unifier, supply, env, adts, known, head)?;
1371                for arg in args {
1372                    infer_typed_app_expr_arg(unifier, supply, env, adts, known, &mut state, arg)?;
1373                }
1374                Ok((state.preds, state.func_ty, state.typed))
1375            }
1376            Expr::Project(_, base, field) => {
1377                let (p1, t1, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
1378                let base_ty = unifier.apply_type(&t1);
1379                let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
1380                let field_ty =
1381                    resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
1382                let typed = TypedExpr::new(
1383                    field_ty.clone(),
1384                    TypedExprKind::Project {
1385                        expr: Arc::new(typed_base),
1386                        field: field.clone(),
1387                    },
1388                );
1389                Ok((p1, field_ty, typed))
1390            }
1391            Expr::RecordUpdate(_, base, updates) => {
1392                let (p_base, t_base, typed_base) =
1393                    infer_expr(unifier, supply, env, adts, known, base)?;
1394                let base_ty = unifier.apply_type(&t_base);
1395                let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
1396                let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
1397                let (result_ty, fields) = resolve_record_update(
1398                    unifier,
1399                    supply,
1400                    adts,
1401                    &base_ty,
1402                    known_variant,
1403                    &update_fields,
1404                )?;
1405                let expected: BTreeMap<_, _> = fields.into_iter().collect();
1406
1407                let mut preds = p_base;
1408                let mut typed_updates = BTreeMap::new();
1409                for (k, v) in updates {
1410                    let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
1411                        field: k.clone(),
1412                        typ: result_ty.to_string(),
1413                    })?;
1414                    let (p1, t1, typed_v) =
1415                        infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
1416                    unifier.unify(&t1, expected_ty)?;
1417                    preds.extend(p1);
1418                    typed_updates.insert(k.clone(), Arc::new(typed_v));
1419                }
1420                let typed = TypedExpr::new(
1421                    result_ty.clone(),
1422                    TypedExprKind::RecordUpdate {
1423                        base: Arc::new(typed_base),
1424                        updates: typed_updates,
1425                    },
1426                );
1427                Ok((preds, result_ty, typed))
1428            }
1429            Expr::Let(..) => {
1430                let mut bindings = Vec::new();
1431                let mut cur = expr;
1432                while let Expr::Let(_, v, ann, d, b) = cur {
1433                    bindings.push((v.clone(), ann.clone(), d.clone()));
1434                    cur = b.as_ref();
1435                }
1436
1437                let mut env_cur = env.clone();
1438                let mut known_cur = known.clone();
1439                let mut typed_defs = Vec::new();
1440                for (v, ann, d) in bindings {
1441                    let (p1, t1, typed_def) = if let Some(ref ann_expr) = ann {
1442                        let mut ann_vars = BTreeMap::new();
1443                        let ann_ty =
1444                            type_from_annotation_expr_vars(adts, ann_expr, &mut ann_vars, supply)?;
1445                        match d.as_ref() {
1446                            Expr::RecordUpdate(_, base, updates) => {
1447                                infer_record_update_typed_with_hint(
1448                                    unifier,
1449                                    supply,
1450                                    &env_cur,
1451                                    adts,
1452                                    &known_cur,
1453                                    base.as_ref(),
1454                                    updates,
1455                                    &ann_ty,
1456                                )?
1457                            }
1458                            _ => {
1459                                let (p1, t1, typed_def) =
1460                                    infer_expr(unifier, supply, &env_cur, adts, &known_cur, &d)?;
1461                                unifier.unify(&t1, &ann_ty)?;
1462                                (p1, t1, typed_def)
1463                            }
1464                        }
1465                    } else {
1466                        infer_expr(unifier, supply, &env_cur, adts, &known_cur, &d)?
1467                    };
1468                    let def_ty = unifier.apply_type(&t1);
1469                    let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
1470                        monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
1471                    } else {
1472                        let scheme = generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
1473                        reject_ambiguous_scheme(&scheme)?;
1474                        scheme
1475                    };
1476                    env_cur.extend(v.name.clone(), scheme);
1477                    if let Some(known_variant) =
1478                        known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
1479                    {
1480                        known_cur.insert(
1481                            v.name.clone(),
1482                            KnownVariant {
1483                                adt: known_variant.adt,
1484                                variant: known_variant.variant,
1485                            },
1486                        );
1487                    } else {
1488                        known_cur.remove(&v.name);
1489                    }
1490                    typed_defs.push((v.name.clone(), typed_def));
1491                }
1492
1493                let (p_body, t_body, typed_body) =
1494                    infer_expr(unifier, supply, &env_cur, adts, &known_cur, cur)?;
1495
1496                let mut typed = typed_body;
1497                for (name, def) in typed_defs.into_iter().rev() {
1498                    typed = TypedExpr::new(
1499                        t_body.clone(),
1500                        TypedExprKind::Let {
1501                            name,
1502                            def: Arc::new(def),
1503                            body: Arc::new(typed),
1504                        },
1505                    );
1506                }
1507                Ok((p_body, t_body, typed))
1508            }
1509            Expr::LetRec(_, bindings, body) => {
1510                let mut env_seed = env.clone();
1511                let mut known_seed = known.clone();
1512                let mut binding_tys = BTreeMap::new();
1513                for (var, _ann, _def) in bindings {
1514                    let tv = Type::var(supply.fresh(Some(var.name.clone())));
1515                    env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
1516                    known_seed.remove(&var.name);
1517                    binding_tys.insert(var.name.clone(), tv);
1518                }
1519
1520                let mut inferred_defs = Vec::with_capacity(bindings.len());
1521                for (var, ann, def) in bindings {
1522                    let (preds, def_ty, typed_def) =
1523                        infer_expr(unifier, supply, &env_seed, adts, &known_seed, def)?;
1524                    if let Some(ann) = ann {
1525                        let mut ann_vars = BTreeMap::new();
1526                        let ann_ty =
1527                            type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
1528                        unifier.unify(&def_ty, &ann_ty)?;
1529                    }
1530                    let binding_ty = binding_tys
1531                        .get(&var.name)
1532                        .cloned()
1533                        .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1534                    unifier.unify(&binding_ty, &def_ty)?;
1535                    let resolved_ty = unifier.apply_type(&binding_ty);
1536
1537                    if let Some(known_variant) =
1538                        known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
1539                    {
1540                        known_seed.insert(
1541                            var.name.clone(),
1542                            KnownVariant {
1543                                adt: known_variant.adt,
1544                                variant: known_variant.variant,
1545                            },
1546                        );
1547                    } else {
1548                        known_seed.remove(&var.name);
1549                    }
1550                    inferred_defs.push((var.name.clone(), preds, resolved_ty, typed_def));
1551                }
1552
1553                let mut env_body = env.clone();
1554                let mut typed_bindings = Vec::with_capacity(inferred_defs.len());
1555                for (name, preds, def_ty, typed_def) in inferred_defs {
1556                    let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
1557                    reject_ambiguous_scheme(&scheme)?;
1558                    env_body.extend(name.clone(), scheme);
1559                    typed_bindings.push((name, Arc::new(typed_def)));
1560                }
1561
1562                let (p_body, t_body, typed_body) =
1563                    infer_expr(unifier, supply, &env_body, adts, &known_seed, body)?;
1564                let typed = TypedExpr::new(
1565                    t_body.clone(),
1566                    TypedExprKind::LetRec {
1567                        bindings: typed_bindings,
1568                        body: Arc::new(typed_body),
1569                    },
1570                );
1571                Ok((p_body, t_body, typed))
1572            }
1573            Expr::Ite(_, cond, then_expr, else_expr) => {
1574                let (p1, t1, typed_cond) = infer_expr(unifier, supply, env, adts, known, cond)?;
1575                unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
1576                let (p2, t2, typed_then) =
1577                    infer_expr(unifier, supply, env, adts, known, then_expr)?;
1578                let (p3, t3, typed_else) =
1579                    infer_expr(unifier, supply, env, adts, known, else_expr)?;
1580                unifier.unify(&t2, &t3)?;
1581                let out_ty = unifier.apply_type(&t2);
1582                let mut preds = p1;
1583                preds.extend(p2);
1584                preds.extend(p3);
1585                let typed = TypedExpr::new(
1586                    out_ty.clone(),
1587                    TypedExprKind::Ite {
1588                        cond: Arc::new(typed_cond),
1589                        then_expr: Arc::new(typed_then),
1590                        else_expr: Arc::new(typed_else),
1591                    },
1592                );
1593                Ok((preds, out_ty, typed))
1594            }
1595            Expr::Tuple(_, elems) => {
1596                let mut preds = Vec::new();
1597                let mut types = Vec::new();
1598                let mut typed_elems = Vec::new();
1599                for elem in elems {
1600                    let (p1, t1, typed_elem) = infer_expr(unifier, supply, env, adts, known, elem)?;
1601                    preds.extend(p1);
1602                    types.push(unifier.apply_type(&t1));
1603                    typed_elems.push(Arc::new(typed_elem));
1604                }
1605                let tuple_ty = Type::tuple(types);
1606                let typed = TypedExpr::new(tuple_ty.clone(), TypedExprKind::Tuple(typed_elems));
1607                Ok((preds, tuple_ty, typed))
1608            }
1609            Expr::List(_, elems) => {
1610                let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("a"))));
1611                let mut preds = Vec::new();
1612                let mut typed_elems = Vec::new();
1613                for elem in elems {
1614                    let (p1, t1, typed_elem) = infer_expr(unifier, supply, env, adts, known, elem)?;
1615                    unifier.unify(&t1, &elem_tv)?;
1616                    preds.extend(p1);
1617                    typed_elems.push(Arc::new(typed_elem));
1618                }
1619                let list_ty = Type::app(
1620                    Type::builtin(BuiltinTypeId::List),
1621                    unifier.apply_type(&elem_tv),
1622                );
1623                let typed = TypedExpr::new(list_ty.clone(), TypedExprKind::List(typed_elems));
1624                Ok((preds, list_ty, typed))
1625            }
1626            Expr::Dict(_, kvs) => {
1627                let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("v"))));
1628                let mut preds = Vec::new();
1629                let mut typed_kvs = BTreeMap::new();
1630                for (k, v) in kvs {
1631                    let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v)?;
1632                    unifier.unify(&t1, &elem_tv)?;
1633                    preds.extend(p1);
1634                    typed_kvs.insert(k.clone(), Arc::new(typed_v));
1635                }
1636                let dict_ty = Type::app(
1637                    Type::builtin(BuiltinTypeId::Dict),
1638                    unifier.apply_type(&elem_tv),
1639                );
1640                let typed = TypedExpr::new(dict_ty.clone(), TypedExprKind::Dict(typed_kvs));
1641                Ok((preds, dict_ty, typed))
1642            }
1643            Expr::Match(_, scrutinee, arms) => {
1644                let (p1, t1, typed_scrutinee) =
1645                    infer_expr(unifier, supply, env, adts, known, scrutinee)?;
1646                let mut preds = p1;
1647                let mut typed_arms = Vec::new();
1648                let res_ty = Type::var(supply.fresh(Some(Symbol::intern("match"))));
1649                let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
1650
1651                for (pat, expr) in arms {
1652                    let scrutinee_ty = unifier.apply_type(&t1);
1653                    let (p_pat, binds) = infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
1654                    preds.extend(p_pat);
1655
1656                    let mut env_arm = env.clone();
1657                    for (name, ty) in binds {
1658                        env_arm.extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
1659                    }
1660                    let mut known_arm = known.clone();
1661                    if let Expr::Var(var) = scrutinee.as_ref() {
1662                        match pat {
1663                            Pattern::Named(_, name, _) => {
1664                                let name_sym = name.to_dotted_symbol();
1665                                if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
1666                                    known_arm.insert(
1667                                        var.name.clone(),
1668                                        KnownVariant {
1669                                            adt: adt.name.clone(),
1670                                            variant: name_sym,
1671                                        },
1672                                    );
1673                                } else {
1674                                    known_arm.remove(&var.name);
1675                                }
1676                            }
1677                            _ => {
1678                                known_arm.remove(&var.name);
1679                            }
1680                        }
1681                    }
1682                    let (p_expr, t_expr, typed_expr) =
1683                        infer_expr(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1684                    unifier.unify(&res_ty, &t_expr)?;
1685                    preds.extend(p_expr);
1686                    typed_arms.push((pat.clone(), Arc::new(typed_expr)));
1687                }
1688
1689                let scrutinee_ty = unifier.apply_type(&t1);
1690                check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1691                let out_ty = unifier.apply_type(&res_ty);
1692                let typed = TypedExpr::new(
1693                    out_ty.clone(),
1694                    TypedExprKind::Match {
1695                        scrutinee: Arc::new(typed_scrutinee),
1696                        arms: typed_arms,
1697                    },
1698                );
1699                Ok((preds, out_ty, typed))
1700            }
1701            Expr::Ann(_, expr, ann) => {
1702                let ann_ty = type_from_annotation_expr(adts, ann)?;
1703                match expr.as_ref() {
1704                    Expr::RecordUpdate(_, base, updates) => infer_record_update_typed_with_hint(
1705                        unifier,
1706                        supply,
1707                        env,
1708                        adts,
1709                        known,
1710                        base.as_ref(),
1711                        updates,
1712                        &ann_ty,
1713                    ),
1714                    _ => {
1715                        let (preds, expr_ty, typed_expr) =
1716                            infer_expr(unifier, supply, env, adts, known, expr)?;
1717                        unifier.unify(&expr_ty, &ann_ty)?;
1718                        let out_ty = unifier.apply_type(&ann_ty);
1719                        Ok((preds, out_ty, typed_expr))
1720                    }
1721                }
1722            }
1723        })()
1724    });
1725    res.map_err(|err| err.with_span(&span))
1726}
1727
1728fn ctor_lookup<'a>(
1729    adts: &'a BTreeMap<Symbol, AdtDecl>,
1730    name: &Symbol,
1731) -> Option<(&'a AdtDecl, &'a AdtVariant)> {
1732    let mut found = None;
1733    for adt in adts.values() {
1734        if let Some(variant) = adt.variants.iter().find(|v| &v.name == name) {
1735            if found.is_some() {
1736                return None;
1737            }
1738            found = Some((adt, variant));
1739        }
1740    }
1741    found
1742}
1743
1744fn record_fields(variant: &AdtVariant) -> Option<&[(Symbol, Type)]> {
1745    if variant.args.len() != 1 {
1746        return None;
1747    }
1748    match variant.args[0].as_ref() {
1749        TypeKind::Record(fields) => Some(fields),
1750        _ => None,
1751    }
1752}
1753
1754fn instantiate_variant_fields(
1755    adt: &AdtDecl,
1756    variant: &AdtVariant,
1757    supply: &mut TypeVarSupply,
1758) -> Option<(Type, Vec<(Symbol, Type)>)> {
1759    let fields = record_fields(variant)?;
1760    let mut subst = Subst::new_sync();
1761    for param in &adt.params {
1762        let fresh = Type::var(supply.fresh(param.var.name.clone()));
1763        subst = subst.insert(param.var.id, fresh);
1764    }
1765    let result_ty = adt.result_type().apply(&subst);
1766    let fields = fields
1767        .iter()
1768        .map(|(name, ty)| (name.clone(), ty.apply(&subst)))
1769        .collect();
1770    Some((result_ty, fields))
1771}
1772
1773fn known_variant_from_expr(
1774    expr: &Expr,
1775    expr_ty: &Type,
1776    adts: &BTreeMap<Symbol, AdtDecl>,
1777) -> Option<KnownVariant> {
1778    let mut expr = expr;
1779    while let Expr::Ann(_, inner, _) = expr {
1780        expr = inner.as_ref();
1781    }
1782    if matches!(expr_ty.as_ref(), TypeKind::Fun(..)) {
1783        return None;
1784    }
1785    let ctor = match expr {
1786        Expr::App(_, f, _) => match f.as_ref() {
1787            Expr::Var(var) => var.name.clone(),
1788            _ => return None,
1789        },
1790        _ => return None,
1791    };
1792    let (adt, variant) = ctor_lookup(adts, &ctor)?;
1793    record_fields(variant)?;
1794    Some(KnownVariant {
1795        adt: adt.name.clone(),
1796        variant: variant.name.clone(),
1797    })
1798}
1799
1800fn known_variant_from_expr_with_known(
1801    expr: &Expr,
1802    expr_ty: &Type,
1803    adts: &BTreeMap<Symbol, AdtDecl>,
1804    known: &KnownVariants,
1805) -> Option<KnownVariant> {
1806    let mut expr = expr;
1807    while let Expr::Ann(_, inner, _) = expr {
1808        expr = inner.as_ref();
1809    }
1810    match expr {
1811        Expr::Var(var) => known.get(&var.name).cloned(),
1812        Expr::RecordUpdate(_, base, _) => {
1813            known_variant_from_expr_with_known(base.as_ref(), expr_ty, adts, known)
1814        }
1815        _ => known_variant_from_expr(expr, expr_ty, adts),
1816    }
1817}
1818
1819fn select_record_variant<'a, F>(
1820    adts: &'a BTreeMap<Symbol, AdtDecl>,
1821    base_ty: &Type,
1822    known_variant: Option<KnownVariant>,
1823    field_for_errors: &Symbol,
1824    matches_fields: F,
1825) -> Result<(&'a AdtDecl, &'a AdtVariant), TypeError>
1826where
1827    F: Fn(&[(Symbol, Type)]) -> bool,
1828{
1829    if let Some(info) = known_variant {
1830        let adt = adts
1831            .get(&info.adt)
1832            .ok_or_else(|| TypeError::UnknownTypeName(info.adt.clone()))?;
1833        let variant = adt
1834            .variants
1835            .iter()
1836            .find(|v| v.name == info.variant)
1837            .ok_or_else(|| TypeError::UnknownField {
1838                field: field_for_errors.clone(),
1839                typ: base_ty.to_string(),
1840            })?;
1841        return Ok((adt, variant));
1842    }
1843
1844    if let Some(adt_name) = type_head_name(base_ty) {
1845        let adt = adts.get(adt_name).ok_or_else(|| TypeError::UnknownField {
1846            field: field_for_errors.clone(),
1847            typ: base_ty.to_string(),
1848        })?;
1849        if adt.variants.len() == 1 {
1850            return Ok((adt, &adt.variants[0]));
1851        }
1852        return Err(TypeError::FieldNotKnown {
1853            field: field_for_errors.clone(),
1854            typ: base_ty.to_string(),
1855        });
1856    }
1857
1858    if matches!(base_ty.as_ref(), TypeKind::Var(_)) {
1859        let mut candidates = Vec::new();
1860        for adt in adts.values() {
1861            if adt.variants.len() != 1 {
1862                continue;
1863            }
1864            let variant = &adt.variants[0];
1865            let Some(fields) = record_fields(variant) else {
1866                continue;
1867            };
1868            if matches_fields(fields) {
1869                candidates.push((adt, variant));
1870            }
1871        }
1872        if candidates.len() == 1 {
1873            return Ok(candidates.remove(0));
1874        }
1875        if candidates.is_empty() {
1876            return Err(TypeError::UnknownField {
1877                field: field_for_errors.clone(),
1878                typ: base_ty.to_string(),
1879            });
1880        }
1881        return Err(TypeError::FieldNotKnown {
1882            field: field_for_errors.clone(),
1883            typ: base_ty.to_string(),
1884        });
1885    }
1886
1887    Err(TypeError::UnknownField {
1888        field: field_for_errors.clone(),
1889        typ: base_ty.to_string(),
1890    })
1891}
1892
1893fn resolve_record_update(
1894    unifier: &mut Unifier,
1895    supply: &mut TypeVarSupply,
1896    adts: &BTreeMap<Symbol, AdtDecl>,
1897    base_ty: &Type,
1898    known_variant: Option<KnownVariant>,
1899    update_fields: &[Symbol],
1900) -> Result<(Type, Vec<(Symbol, Type)>), TypeError> {
1901    if let TypeKind::Record(fields) = base_ty.as_ref() {
1902        return Ok((base_ty.clone(), fields.clone()));
1903    }
1904
1905    let field_for_errors = update_fields
1906        .first()
1907        .cloned()
1908        .unwrap_or_else(|| Symbol::intern("_"));
1909
1910    let (adt, variant) =
1911        select_record_variant(adts, base_ty, known_variant, &field_for_errors, |fields| {
1912            update_fields
1913                .iter()
1914                .all(|field| fields.iter().any(|(name, _)| name == field))
1915        })?;
1916
1917    let (result_ty, fields) =
1918        instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1919            TypeError::UnknownField {
1920                field: field_for_errors.clone(),
1921                typ: base_ty.to_string(),
1922            }
1923        })?;
1924
1925    for field in update_fields {
1926        if fields.iter().all(|(name, _)| name != field) {
1927            return Err(TypeError::UnknownField {
1928                field: field.clone(),
1929                typ: base_ty.to_string(),
1930            });
1931        }
1932    }
1933
1934    unifier.unify(base_ty, &result_ty)?;
1935    let result_ty = unifier.apply_type(&result_ty);
1936    let fields = fields
1937        .into_iter()
1938        .map(|(name, ty)| (name, unifier.apply_type(&ty)))
1939        .collect();
1940    Ok((result_ty, fields))
1941}
1942
1943fn resolve_projection(
1944    unifier: &mut Unifier,
1945    supply: &mut TypeVarSupply,
1946    adts: &BTreeMap<Symbol, AdtDecl>,
1947    base_ty: &Type,
1948    known_variant: Option<KnownVariant>,
1949    field: &Symbol,
1950) -> Result<Type, TypeError> {
1951    if let Ok(index) = field.as_ref().parse::<usize>() {
1952        let elem_ty = match base_ty.as_ref() {
1953            TypeKind::Tuple(elems) => {
1954                elems
1955                    .get(index)
1956                    .cloned()
1957                    .ok_or_else(|| TypeError::UnknownField {
1958                        field: field.clone(),
1959                        typ: base_ty.to_string(),
1960                    })?
1961            }
1962            TypeKind::Var(_) => {
1963                let mut elems = Vec::with_capacity(index + 1);
1964                for _ in 0..=index {
1965                    elems.push(Type::var(supply.fresh(Some(Symbol::intern("t")))));
1966                }
1967                let tuple_ty = Type::tuple(elems.clone());
1968                unifier.unify(base_ty, &tuple_ty)?;
1969                elems[index].clone()
1970            }
1971            _ => {
1972                return Err(TypeError::UnknownField {
1973                    field: field.clone(),
1974                    typ: base_ty.to_string(),
1975                });
1976            }
1977        };
1978        return Ok(unifier.apply_type(&elem_ty));
1979    }
1980
1981    let (adt, variant) = select_record_variant(adts, base_ty, known_variant, field, |fields| {
1982        fields.iter().any(|(name, _)| name == field)
1983    })?;
1984
1985    let (result_ty, fields) =
1986        instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1987            TypeError::UnknownField {
1988                field: field.clone(),
1989                typ: base_ty.to_string(),
1990            }
1991        })?;
1992    let field_ty = fields
1993        .iter()
1994        .find(|(name, _)| name == field)
1995        .map(|(_, ty)| ty.clone())
1996        .ok_or_else(|| TypeError::UnknownField {
1997            field: field.clone(),
1998            typ: base_ty.to_string(),
1999        })?;
2000    unifier.unify(base_ty, &result_ty)?;
2001    Ok(unifier.apply_type(&field_ty))
2002}
2003
2004fn decompose_fun(typ: &Type, arity: usize) -> Option<(Vec<Type>, Type)> {
2005    let mut args = Vec::with_capacity(arity);
2006    let mut cur = typ.clone();
2007    for _ in 0..arity {
2008        match cur.as_ref() {
2009            TypeKind::Fun(a, b) => {
2010                args.push(a.clone());
2011                cur = b.clone();
2012            }
2013            _ => return None,
2014        }
2015    }
2016    Some((args, cur))
2017}
2018
2019type InferPatternResult = (Vec<Predicate>, Vec<(Symbol, Type)>);
2020
2021fn infer_pattern(
2022    unifier: &mut Unifier,
2023    supply: &mut TypeVarSupply,
2024    env: &TypeEnv,
2025    pat: &Pattern,
2026    scrutinee_ty: &Type,
2027) -> Result<InferPatternResult, TypeError> {
2028    let span = *pat.span();
2029    let res = (|| match pat {
2030        Pattern::Wildcard(..) => Ok((vec![], vec![])),
2031        Pattern::Var(var) => Ok((
2032            vec![],
2033            vec![(var.name.clone(), unifier.apply_type(scrutinee_ty))],
2034        )),
2035        Pattern::Named(_, name, ps) => {
2036            let ctor_name = name.to_dotted_symbol();
2037            let schemes = env
2038                .lookup(&ctor_name)
2039                .ok_or_else(|| TypeError::UnknownVar(ctor_name.clone()))?;
2040            if schemes.len() != 1 {
2041                return Err(TypeError::AmbiguousOverload(ctor_name));
2042            }
2043            let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
2044            let (preds, ctor_ty) = instantiate(&scheme, supply);
2045            let (arg_tys, res_ty) = decompose_fun(&ctor_ty, ps.len())
2046                .ok_or(TypeError::UnsupportedExpr("pattern constructor"))?;
2047            unifier.unify(&res_ty, scrutinee_ty)?;
2048            let mut all_preds = preds;
2049            let mut bindings = Vec::new();
2050            for (p, arg_ty) in ps.iter().zip(arg_tys.iter()) {
2051                let arg_ty = unifier.apply_type(arg_ty);
2052                let (p1, binds1) = infer_pattern(unifier, supply, env, p, &arg_ty)?;
2053                all_preds.extend(p1);
2054                bindings.extend(binds1);
2055            }
2056            let bindings = bindings
2057                .into_iter()
2058                .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2059                .collect();
2060            Ok((all_preds, bindings))
2061        }
2062        Pattern::List(_, ps) => {
2063            let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("a"))));
2064            let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
2065            unifier.unify(scrutinee_ty, &list_ty)?;
2066            let mut preds = Vec::new();
2067            let mut bindings = Vec::new();
2068            for p in ps {
2069                let elem_ty = unifier.apply_type(&elem_tv);
2070                let (p1, binds1) = infer_pattern(unifier, supply, env, p, &elem_ty)?;
2071                preds.extend(p1);
2072                bindings.extend(binds1);
2073            }
2074            let bindings = bindings
2075                .into_iter()
2076                .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2077                .collect();
2078            Ok((preds, bindings))
2079        }
2080        Pattern::Cons(_, head, tail) => {
2081            let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("a"))));
2082            let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
2083            unifier.unify(scrutinee_ty, &list_ty)?;
2084            let mut preds = Vec::new();
2085            let mut bindings = Vec::new();
2086
2087            let head_ty = unifier.apply_type(&elem_tv);
2088            let (p1, binds1) = infer_pattern(unifier, supply, env, head, &head_ty)?;
2089            preds.extend(p1);
2090            bindings.extend(binds1);
2091
2092            let tail_ty = Type::app(
2093                Type::builtin(BuiltinTypeId::List),
2094                unifier.apply_type(&elem_tv),
2095            );
2096            let (p2, binds2) = infer_pattern(unifier, supply, env, tail, &tail_ty)?;
2097            preds.extend(p2);
2098            bindings.extend(binds2);
2099
2100            let bindings = bindings
2101                .into_iter()
2102                .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2103                .collect();
2104            Ok((preds, bindings))
2105        }
2106        Pattern::Tuple(_, elems) => {
2107            let mut elem_tys: Vec<Type> = (0..elems.len())
2108                .map(|i| Type::var(supply.fresh(Some(Symbol::intern(&format!("t{i}"))))))
2109                .collect();
2110            let expected = Type::tuple(elem_tys.clone());
2111            unifier.unify(scrutinee_ty, &expected)?;
2112            elem_tys = elem_tys
2113                .into_iter()
2114                .map(|t| unifier.apply_type(&t))
2115                .collect();
2116
2117            let mut preds = Vec::new();
2118            let mut bindings = Vec::new();
2119            for (p, ty) in elems.iter().zip(elem_tys.iter()) {
2120                let (p_preds, p_binds) = infer_pattern(unifier, supply, env, p, ty)?;
2121                preds.extend(p_preds);
2122                bindings.extend(p_binds);
2123            }
2124            let bindings = bindings
2125                .into_iter()
2126                .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2127                .collect();
2128            Ok((preds, bindings))
2129        }
2130        Pattern::Dict(_, fields) => {
2131            if let TypeKind::Record(ty_fields) = scrutinee_ty.as_ref() {
2132                let mut preds = Vec::new();
2133                let mut bindings = Vec::new();
2134                for (key, pat) in fields {
2135                    let ty = ty_fields
2136                        .iter()
2137                        .find(|(name, _)| name == key)
2138                        .map(|(_, ty)| unifier.apply_type(ty))
2139                        .ok_or_else(|| TypeError::UnknownField {
2140                            field: key.clone(),
2141                            typ: scrutinee_ty.to_string(),
2142                        })?;
2143                    let (p_preds, p_binds) = infer_pattern(unifier, supply, env, pat, &ty)?;
2144                    preds.extend(p_preds);
2145                    bindings.extend(p_binds);
2146                }
2147                let bindings = bindings
2148                    .into_iter()
2149                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2150                    .collect();
2151                Ok((preds, bindings))
2152            } else {
2153                let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("v"))));
2154                let dict_ty = Type::app(Type::builtin(BuiltinTypeId::Dict), elem_tv.clone());
2155                unifier.unify(scrutinee_ty, &dict_ty)?;
2156                let elem_ty = unifier.apply_type(&elem_tv);
2157
2158                let mut preds = Vec::new();
2159                let mut bindings = Vec::new();
2160                for (_key, pat) in fields {
2161                    let (p_preds, p_binds) = infer_pattern(unifier, supply, env, pat, &elem_ty)?;
2162                    preds.extend(p_preds);
2163                    bindings.extend(p_binds);
2164                }
2165                let bindings = bindings
2166                    .into_iter()
2167                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2168                    .collect();
2169                Ok((preds, bindings))
2170            }
2171        }
2172    })();
2173    res.map_err(|err| err.with_span(&span))
2174}
2175
2176fn type_head_name(typ: &Type) -> Option<&Symbol> {
2177    let mut cur = typ;
2178    while let TypeKind::App(head, _) = cur.as_ref() {
2179        cur = head;
2180    }
2181    match cur.as_ref() {
2182        TypeKind::Con(tc) => tc.user_name(),
2183        _ => None,
2184    }
2185}
2186
2187fn adt_name_from_patterns(
2188    adts: &BTreeMap<Symbol, AdtDecl>,
2189    patterns: &[Pattern],
2190) -> Option<Symbol> {
2191    let mut candidate: Option<Symbol> = None;
2192    for pat in patterns {
2193        let next = match pat {
2194            Pattern::Named(_, name, _) => {
2195                let name_sym = name.to_dotted_symbol();
2196                ctor_lookup(adts, &name_sym).map(|(adt, _)| adt.name.clone())
2197            }
2198            Pattern::List(..) | Pattern::Cons(..) => Some(Symbol::intern("List")),
2199            _ => None,
2200        };
2201        if let Some(next) = next {
2202            match &candidate {
2203                None => candidate = Some(next),
2204                Some(prev) if *prev == next => {}
2205                Some(_) => return None,
2206            }
2207        }
2208    }
2209    candidate
2210}
2211
2212fn check_match_exhaustive(
2213    adts: &BTreeMap<Symbol, AdtDecl>,
2214    scrutinee_ty: &Type,
2215    patterns: &[Pattern],
2216) -> Result<(), TypeError> {
2217    if patterns
2218        .iter()
2219        .any(|p| matches!(p, Pattern::Wildcard(..) | Pattern::Var(_)))
2220    {
2221        return Ok(());
2222    }
2223    let adt_name = match type_head_name(scrutinee_ty).cloned() {
2224        Some(name) => name,
2225        None => match adt_name_from_patterns(adts, patterns) {
2226            Some(name) => name,
2227            None => return Ok(()),
2228        },
2229    };
2230    let adt = match adts.get(&adt_name) {
2231        Some(adt) => adt,
2232        None => return Ok(()),
2233    };
2234    let ctor_names: BTreeSet<Symbol> = adt.variants.iter().map(|v| v.name.clone()).collect();
2235    if ctor_names.is_empty() {
2236        return Ok(());
2237    }
2238    let mut covered = BTreeSet::new();
2239    for pat in patterns {
2240        match pat {
2241            Pattern::Named(_, name, _) => {
2242                let name_sym = name.to_dotted_symbol();
2243                if ctor_names.contains(&name_sym) {
2244                    covered.insert(name_sym);
2245                }
2246            }
2247            Pattern::List(_, elems) if adt_name.as_ref() == "List" && elems.is_empty() => {
2248                covered.insert(Symbol::intern("Empty"));
2249            }
2250            Pattern::Cons(..) if adt_name.as_ref() == "List" => {
2251                covered.insert(Symbol::intern("Cons"));
2252            }
2253            _ => {}
2254        }
2255    }
2256    let mut missing: Vec<Symbol> = ctor_names.difference(&covered).cloned().collect();
2257    if missing.is_empty() {
2258        return Ok(());
2259    }
2260    missing.sort();
2261    Err(TypeError::NonExhaustiveMatch {
2262        typ: scrutinee_ty.to_string(),
2263        missing,
2264    })
2265}