Skip to main content

rexlang_typesystem/
typesystem.rs

1//! Core type system implementation for Rex.
2
3use std::collections::{BTreeMap, HashMap, HashSet};
4use std::fmt::{self, Display, Formatter};
5use std::sync::Arc;
6
7use chrono::{DateTime, Utc};
8use rexlang_ast::expr::{
9    ClassDecl, ClassMethodSig, Decl, DeclareFnDecl, Expr, FnDecl, InstanceDecl, InstanceMethodImpl,
10    Pattern, Scope, Symbol, TypeConstraint, TypeDecl, TypeExpr, intern, sym,
11};
12use rexlang_lexer::span::Span;
13use rexlang_util::{GasMeter, OutOfGas};
14use rpds::HashTrieMapSync;
15use uuid::Uuid;
16
17use crate::prelude;
18
19pub type TypeVarId = usize;
20
/// Identifies a type constructor that is built into the language.
///
/// Scalar variants (`U8` through `DateTime`) take no type arguments; the
/// container variants take the number reported by [`BuiltinTypeId::arity`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BuiltinTypeId {
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F32,
    F64,
    Bool,
    String,
    Uuid,
    DateTime,
    List,
    Array,
    Dict,
    Option,
    Result,
}
43
44impl BuiltinTypeId {
45    pub fn as_symbol(self) -> Symbol {
46        sym(self.as_str())
47    }
48
49    pub fn as_str(self) -> &'static str {
50        match self {
51            Self::U8 => "u8",
52            Self::U16 => "u16",
53            Self::U32 => "u32",
54            Self::U64 => "u64",
55            Self::I8 => "i8",
56            Self::I16 => "i16",
57            Self::I32 => "i32",
58            Self::I64 => "i64",
59            Self::F32 => "f32",
60            Self::F64 => "f64",
61            Self::Bool => "bool",
62            Self::String => "string",
63            Self::Uuid => "uuid",
64            Self::DateTime => "datetime",
65            Self::List => "List",
66            Self::Array => "Array",
67            Self::Dict => "Dict",
68            Self::Option => "Option",
69            Self::Result => "Result",
70        }
71    }
72
73    pub fn arity(self) -> usize {
74        match self {
75            Self::List | Self::Array | Self::Dict | Self::Option => 1,
76            Self::Result => 2,
77            _ => 0,
78        }
79    }
80
81    pub fn from_symbol(name: &Symbol) -> Option<Self> {
82        Self::from_name(name.as_ref())
83    }
84
85    pub fn from_name(name: &str) -> Option<Self> {
86        match name {
87            "u8" => Some(Self::U8),
88            "u16" => Some(Self::U16),
89            "u32" => Some(Self::U32),
90            "u64" => Some(Self::U64),
91            "i8" => Some(Self::I8),
92            "i16" => Some(Self::I16),
93            "i32" => Some(Self::I32),
94            "i64" => Some(Self::I64),
95            "f32" => Some(Self::F32),
96            "f64" => Some(Self::F64),
97            "bool" => Some(Self::Bool),
98            "string" => Some(Self::String),
99            "uuid" => Some(Self::Uuid),
100            "datetime" => Some(Self::DateTime),
101            "List" => Some(Self::List),
102            "Array" => Some(Self::Array),
103            "Dict" => Some(Self::Dict),
104            "Option" => Some(Self::Option),
105            "Result" => Some(Self::Result),
106            _ => None,
107        }
108    }
109}
110
111#[derive(Clone, Debug, Eq, Hash, PartialEq)]
112pub struct TypeVar {
113    pub id: TypeVarId,
114    pub name: Option<Symbol>,
115}
116
117impl TypeVar {
118    pub fn new(id: TypeVarId, name: impl Into<Option<Symbol>>) -> Self {
119        Self {
120            id,
121            name: name.into(),
122        }
123    }
124}
125
/// A named type constructor.
///
/// `arity` is the number of type arguments it takes. `builtin_id` is `Some`
/// only for constructors built by `Type::builtin` (which `Type::con` uses
/// when the name and arity match a builtin); user constructors carry `None`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TypeConst {
    pub name: Symbol,
    pub arity: usize,
    pub builtin_id: Option<BuiltinTypeId>,
}
132
/// An immutable, shared type term.
///
/// Cloning only bumps the `Arc` reference count; the node is never copied.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Type(Arc<TypeKind>);
135
/// The node shapes a [`Type`] can take.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TypeKind {
    /// A type variable.
    Var(TypeVar),
    /// A (possibly higher-kinded) type constructor such as `i32` or `List`.
    Con(TypeConst),
    /// Curried type application `f arg`; `f a b` is `App(App(f, a), b)`.
    /// Note that `Result` is stored as `Result err ok` (see `Type::result`).
    App(Type, Type),
    /// Function type `arg -> ret`.
    Fun(Type, Type),
    /// Tuple type; elements in positional order.
    Tuple(Vec<Type>),
    /// Record type `{a: T, b: U}`.
    ///
    /// Invariant: fields are sorted by name. This makes record equality and
    /// unification a cheap zip over two vectors, and it makes printing stable.
    Record(Vec<(Symbol, Type)>),
}
149
impl Type {
    /// Wraps `kind` in a freshly allocated, reference-counted node.
    pub fn new(kind: TypeKind) -> Self {
        Type(Arc::new(kind))
    }

    /// Builds a type constructor named `name` with the given `arity`.
    ///
    /// If `name` is a builtin *and* `arity` matches that builtin's arity, the
    /// canonical builtin constructor (with `builtin_id` set) is returned.
    /// Otherwise this falls back to [`Type::user_con`].
    pub fn con(name: impl AsRef<str>, arity: usize) -> Self {
        if let Some(id) = BuiltinTypeId::from_name(name.as_ref())
            && id.arity() == arity
        {
            return Self::builtin(id);
        }
        Self::user_con(name, arity)
    }

    /// Builds a user-defined type constructor. `builtin_id` is always `None`,
    /// even if `name` happens to spell a builtin.
    pub fn user_con(name: impl AsRef<str>, arity: usize) -> Self {
        Type::new(TypeKind::Con(TypeConst {
            name: intern(name.as_ref()),
            arity,
            builtin_id: None,
        }))
    }

    /// Builds the constructor for builtin `id`, taking its name and arity
    /// from [`BuiltinTypeId`].
    pub fn builtin(id: BuiltinTypeId) -> Self {
        Type::new(TypeKind::Con(TypeConst {
            name: id.as_symbol(),
            arity: id.arity(),
            builtin_id: Some(id),
        }))
    }

    /// Builds a type-variable type.
    pub fn var(tv: TypeVar) -> Self {
        Type::new(TypeKind::Var(tv))
    }

    /// Builds the function type `a -> b`.
    pub fn fun(a: Type, b: Type) -> Self {
        Type::new(TypeKind::Fun(a, b))
    }

    /// Builds the type application `f arg`.
    pub fn app(f: Type, arg: Type) -> Self {
        Type::new(TypeKind::App(f, arg))
    }

    /// Builds a tuple type from `elems`, in order.
    pub fn tuple(elems: Vec<Type>) -> Self {
        Type::new(TypeKind::Tuple(elems))
    }

    /// Builds a record type, sorting `fields` by name to establish the
    /// `TypeKind::Record` ordering invariant. Duplicate field names are not
    /// rejected here.
    pub fn record(mut fields: Vec<(Symbol, Type)>) -> Self {
        // Canonicalize records so downstream code can rely on “same shape means
        // same ordering”. (This is a correctness invariant, not a nicety.)
        fields.sort_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
        Type::new(TypeKind::Record(fields))
    }

    /// `List elem`.
    pub fn list(elem: Type) -> Type {
        Type::app(Type::builtin(BuiltinTypeId::List), elem)
    }

    /// `Array elem`.
    pub fn array(elem: Type) -> Type {
        Type::app(Type::builtin(BuiltinTypeId::Array), elem)
    }

    /// `Dict elem` (the builtin `Dict` constructor takes one type argument).
    pub fn dict(elem: Type) -> Type {
        Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
    }

    /// `Option elem`.
    pub fn option(elem: Type) -> Type {
        Type::app(Type::builtin(BuiltinTypeId::Option), elem)
    }

    /// User-facing `Result ok err`.
    ///
    /// Internally encoded as `Result err ok` (error argument applied first),
    /// so that `Result err` is a partial application expecting one argument.
    pub fn result(ok: Type, err: Type) -> Type {
        Type::app(Type::app(Type::builtin(BuiltinTypeId::Result), err), ok)
    }

    /// Applies `s` to this type and reports whether anything changed.
    ///
    /// Unchanged subtrees are returned as clones of the existing `Arc`, with no
    /// new allocation. A bound variable is replaced by its binding as-is: the
    /// binding itself is not substituted again, so the result is fully
    /// substituted only if `s` is idempotent. NOTE(review): presumably
    /// guaranteed by how substitutions are built; confirm.
    fn apply_with_change(&self, s: &Subst) -> (Type, bool) {
        match self.as_ref() {
            TypeKind::Var(tv) => match s.get(&tv.id) {
                Some(ty) => (ty.clone(), true),
                None => (self.clone(), false),
            },
            TypeKind::Con(_) => (self.clone(), false),
            TypeKind::App(l, r) => {
                let (l_new, l_changed) = l.apply_with_change(s);
                let (r_new, r_changed) = r.apply_with_change(s);
                if l_changed || r_changed {
                    (Type::app(l_new, r_new), true)
                } else {
                    (self.clone(), false)
                }
            }
            TypeKind::Fun(_, _) => {
                // Avoid recursive descent on long function chains like
                // `a1 -> a2 -> ... -> an -> r`.
                let mut args = Vec::new();
                let mut changed = false;
                let mut cur: &Type = self;
                while let TypeKind::Fun(a, b) = cur.as_ref() {
                    let (a_new, a_changed) = a.apply_with_change(s);
                    changed |= a_changed;
                    args.push(a_new);
                    cur = b;
                }
                let (ret_new, ret_changed) = cur.apply_with_change(s);
                changed |= ret_changed;
                if !changed {
                    return (self.clone(), false);
                }
                // Rebuild right-to-left so the result is `a1 -> (a2 -> ... -> r)`.
                let mut out = ret_new;
                for a_new in args.into_iter().rev() {
                    out = Type::fun(a_new, out);
                }
                (out, true)
            }
            TypeKind::Tuple(ts) => {
                let mut changed = false;
                let mut out = Vec::with_capacity(ts.len());
                for t in ts {
                    let (t_new, t_changed) = t.apply_with_change(s);
                    changed |= t_changed;
                    out.push(t_new);
                }
                if changed {
                    (Type::new(TypeKind::Tuple(out)), true)
                } else {
                    (self.clone(), false)
                }
            }
            TypeKind::Record(fields) => {
                // Field names are untouched, so the sorted-by-name invariant is
                // preserved without re-sorting.
                let mut changed = false;
                let mut out = Vec::with_capacity(fields.len());
                for (k, v) in fields {
                    let (v_new, v_changed) = v.apply_with_change(s);
                    changed |= v_changed;
                    out.push((k.clone(), v_new));
                }
                if changed {
                    (Type::new(TypeKind::Record(out)), true)
                } else {
                    (self.clone(), false)
                }
            }
        }
    }
}
293
294impl AsRef<TypeKind> for Type {
295    fn as_ref(&self) -> &TypeKind {
296        self.0.as_ref()
297    }
298}
299
300impl std::ops::Deref for Type {
301    type Target = TypeKind;
302
303    fn deref(&self) -> &Self::Target {
304        &self.0
305    }
306}
307
308impl Display for Type {
309    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
310        match self.as_ref() {
311            TypeKind::Var(tv) => match &tv.name {
312                Some(name) => write!(f, "'{}", name),
313                None => write!(f, "t{}", tv.id),
314            },
315            TypeKind::Con(c) => write!(f, "{}", c.name),
316            TypeKind::App(l, r) => {
317                // Internally `Result` is represented as `Result err ok` so it can be partially
318                // applied as `Result err` for HKTs (Functor/Monad/etc).
319                //
320                // User-facing syntax is `Result ok err` (Rust-style), so render the fully
321                // applied form with swapped arguments.
322                if let TypeKind::App(head, err) = l.as_ref()
323                    && matches!(
324                        head.as_ref(),
325                        TypeKind::Con(c)
326                            if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
327                    )
328                {
329                    return write!(f, "(Result {} {})", r, err);
330                }
331                write!(f, "({} {})", l, r)
332            }
333            TypeKind::Fun(a, b) => write!(f, "({} -> {})", a, b),
334            TypeKind::Tuple(elems) => {
335                write!(f, "(")?;
336                for (i, t) in elems.iter().enumerate() {
337                    write!(f, "{}", t)?;
338                    if i + 1 < elems.len() {
339                        write!(f, ", ")?;
340                    }
341                }
342                write!(f, ")")
343            }
344            TypeKind::Record(fields) => {
345                write!(f, "{{")?;
346                for (i, (name, ty)) in fields.iter().enumerate() {
347                    write!(f, "{}: {}", name, ty)?;
348                    if i + 1 < fields.len() {
349                        write!(f, ", ")?;
350                    }
351                }
352                write!(f, "}}")
353            }
354        }
355    }
356}
357
358#[derive(Clone, Debug, PartialEq, Eq, Hash)]
359pub struct Predicate {
360    pub class: Symbol,
361    pub typ: Type,
362}
363
364impl Predicate {
365    pub fn new(class: impl AsRef<str>, typ: Type) -> Self {
366        Self {
367            class: intern(class.as_ref()),
368            typ,
369        }
370    }
371}
372
373#[derive(Clone, Debug, PartialEq)]
374pub struct Scheme {
375    pub vars: Vec<TypeVar>,
376    pub preds: Vec<Predicate>,
377    pub typ: Type,
378}
379
380impl Scheme {
381    pub fn new(vars: Vec<TypeVar>, preds: Vec<Predicate>, typ: Type) -> Self {
382        Self { vars, preds, typ }
383    }
384}
385
/// A substitution mapping type-variable ids to types.
///
/// `insert` returns an updated map rather than mutating in place, which is
/// why call sites write `s = s.insert(k, v)`.
pub type Subst = HashTrieMapSync<TypeVarId, Type>;

/// Values that contain types, so they can be substituted into and scanned
/// for free type variables.
pub trait Types: Sized {
    /// Returns a copy of `self` with `s` applied.
    fn apply(&self, s: &Subst) -> Self;
    /// Returns the ids of the free type variables occurring in `self`.
    fn ftv(&self) -> HashSet<TypeVarId>;
}
392
393impl Types for Type {
394    fn apply(&self, s: &Subst) -> Self {
395        self.apply_with_change(s).0
396    }
397
398    fn ftv(&self) -> HashSet<TypeVarId> {
399        let mut out = HashSet::new();
400        let mut stack: Vec<&Type> = vec![self];
401        while let Some(t) = stack.pop() {
402            match t.as_ref() {
403                TypeKind::Var(tv) => {
404                    out.insert(tv.id);
405                }
406                TypeKind::Con(_) => {}
407                TypeKind::App(l, r) => {
408                    stack.push(l);
409                    stack.push(r);
410                }
411                TypeKind::Fun(a, b) => {
412                    stack.push(a);
413                    stack.push(b);
414                }
415                TypeKind::Tuple(ts) => {
416                    for t in ts {
417                        stack.push(t);
418                    }
419                }
420                TypeKind::Record(fields) => {
421                    for (_, ty) in fields {
422                        stack.push(ty);
423                    }
424                }
425            }
426        }
427        out
428    }
429}
430
431impl Types for Predicate {
432    fn apply(&self, s: &Subst) -> Self {
433        Predicate {
434            class: self.class.clone(),
435            typ: self.typ.apply(s),
436        }
437    }
438
439    fn ftv(&self) -> HashSet<TypeVarId> {
440        self.typ.ftv()
441    }
442}
443
444impl Types for Scheme {
445    fn apply(&self, s: &Subst) -> Self {
446        let mut s_pruned = Subst::new_sync();
447        for (k, v) in s.iter() {
448            if !self.vars.iter().any(|var| var.id == *k) {
449                s_pruned = s_pruned.insert(*k, v.clone());
450            }
451        }
452        Scheme::new(
453            self.vars.clone(),
454            self.preds.iter().map(|p| p.apply(&s_pruned)).collect(),
455            self.typ.apply(&s_pruned),
456        )
457    }
458
459    fn ftv(&self) -> HashSet<TypeVarId> {
460        let mut ftv = self.typ.ftv();
461        for p in &self.preds {
462            ftv.extend(p.ftv());
463        }
464        for v in &self.vars {
465            ftv.remove(&v.id);
466        }
467        ftv
468    }
469}
470
471impl<T: Types> Types for Vec<T> {
472    fn apply(&self, s: &Subst) -> Self {
473        self.iter().map(|t| t.apply(s)).collect()
474    }
475
476    fn ftv(&self) -> HashSet<TypeVarId> {
477        self.iter().flat_map(Types::ftv).collect()
478    }
479}
480
/// An expression node annotated with its type `typ`.
#[derive(Clone, Debug, PartialEq)]
pub struct TypedExpr {
    pub typ: Type,
    pub kind: TypedExprKind,
}
486
impl TypedExpr {
    /// Creates a typed expression node.
    pub fn new(typ: Type, kind: TypedExprKind) -> Self {
        Self { typ, kind }
    }

    /// Returns a copy of this tree with `s` applied to every type it carries:
    /// each node's `typ` and the `overloads` on `Var` nodes. Names, literals
    /// and patterns are copied unchanged.
    ///
    /// Chains of nested `Lam` nodes and left-nested `App` spines are first
    /// flattened into vectors and then rebuilt bottom-up, instead of recursing
    /// once per node along the chain.
    pub fn apply(&self, s: &Subst) -> Self {
        match &self.kind {
            TypedExprKind::Lam { .. } => {
                // Collect (param, substituted node type) for each lambda in the
                // chain, outermost first.
                let mut params: Vec<(Symbol, Type)> = Vec::new();
                let mut cur = self;
                while let TypedExprKind::Lam { param, body } = &cur.kind {
                    params.push((param.clone(), cur.typ.apply(s)));
                    cur = body.as_ref();
                }
                // `cur` is the first non-lambda body; rebuild innermost-out.
                let mut out = cur.apply(s);
                for (param, typ) in params.into_iter().rev() {
                    out = TypedExpr {
                        typ,
                        kind: TypedExprKind::Lam {
                            param,
                            body: Box::new(out),
                        },
                    };
                }
                return out;
            }
            TypedExprKind::App(..) => {
                // Walk down the function side of `((f a) b) c`, recording each
                // application's substituted type and its argument, outermost first.
                let mut apps: Vec<(Type, &TypedExpr)> = Vec::new();
                let mut cur = self;
                while let TypedExprKind::App(f, x) = &cur.kind {
                    apps.push((cur.typ.apply(s), x.as_ref()));
                    cur = f.as_ref();
                }
                // `cur` is the head; re-apply arguments innermost first.
                let mut out = cur.apply(s);
                for (typ, arg) in apps.into_iter().rev() {
                    out = TypedExpr {
                        typ,
                        kind: TypedExprKind::App(Box::new(out), Box::new(arg.apply(s))),
                    };
                }
                return out;
            }
            _ => {}
        }

        let typ = self.typ.apply(s);
        let kind = match &self.kind {
            TypedExprKind::Bool(v) => TypedExprKind::Bool(*v),
            TypedExprKind::Uint(v) => TypedExprKind::Uint(*v),
            TypedExprKind::Int(v) => TypedExprKind::Int(*v),
            TypedExprKind::Float(v) => TypedExprKind::Float(*v),
            TypedExprKind::String(v) => TypedExprKind::String(v.clone()),
            TypedExprKind::Uuid(v) => TypedExprKind::Uuid(*v),
            TypedExprKind::DateTime(v) => TypedExprKind::DateTime(*v),
            TypedExprKind::Hole => TypedExprKind::Hole,
            TypedExprKind::Tuple(elems) => {
                TypedExprKind::Tuple(elems.iter().map(|e| e.apply(s)).collect())
            }
            TypedExprKind::List(elems) => {
                TypedExprKind::List(elems.iter().map(|e| e.apply(s)).collect())
            }
            TypedExprKind::Dict(kvs) => {
                let mut out = BTreeMap::new();
                for (k, v) in kvs {
                    out.insert(k.clone(), v.apply(s));
                }
                TypedExprKind::Dict(out)
            }
            TypedExprKind::RecordUpdate { base, updates } => {
                let mut out = BTreeMap::new();
                for (k, v) in updates {
                    out.insert(k.clone(), v.apply(s));
                }
                TypedExprKind::RecordUpdate {
                    base: Box::new(base.apply(s)),
                    updates: out,
                }
            }
            TypedExprKind::Var { name, overloads } => TypedExprKind::Var {
                name: name.clone(),
                overloads: overloads.iter().map(|t| t.apply(s)).collect(),
            },
            // Not reached: `App` is handled by the spine fast path above. Kept
            // so this match stays exhaustive.
            TypedExprKind::App(f, x) => {
                TypedExprKind::App(Box::new(f.apply(s)), Box::new(x.apply(s)))
            }
            TypedExprKind::Project { expr, field } => TypedExprKind::Project {
                expr: Box::new(expr.apply(s)),
                field: field.clone(),
            },
            // Not reached: `Lam` is handled by the chain fast path above.
            TypedExprKind::Lam { param, body } => TypedExprKind::Lam {
                param: param.clone(),
                body: Box::new(body.apply(s)),
            },
            TypedExprKind::Let { name, def, body } => TypedExprKind::Let {
                name: name.clone(),
                def: Box::new(def.apply(s)),
                body: Box::new(body.apply(s)),
            },
            TypedExprKind::LetRec { bindings, body } => TypedExprKind::LetRec {
                bindings: bindings
                    .iter()
                    .map(|(name, def)| (name.clone(), def.apply(s)))
                    .collect(),
                body: Box::new(body.apply(s)),
            },
            TypedExprKind::Ite {
                cond,
                then_expr,
                else_expr,
            } => TypedExprKind::Ite {
                cond: Box::new(cond.apply(s)),
                then_expr: Box::new(then_expr.apply(s)),
                else_expr: Box::new(else_expr.apply(s)),
            },
            TypedExprKind::Match { scrutinee, arms } => TypedExprKind::Match {
                scrutinee: Box::new(scrutinee.apply(s)),
                arms: arms.iter().map(|(p, e)| (p.clone(), e.apply(s))).collect(),
            },
        };
        TypedExpr { typ, kind }
    }
}
609
/// The shape of a [`TypedExpr`] node; child expressions carry their own types.
#[derive(Clone, Debug, PartialEq)]
pub enum TypedExprKind {
    Bool(bool),
    Uint(u64),
    Int(i64),
    Float(f64),
    String(String),
    Uuid(Uuid),
    DateTime(DateTime<Utc>),
    /// A placeholder expression (NOTE(review): presumably a typed hole; confirm).
    Hole,
    Tuple(Vec<TypedExpr>),
    List(Vec<TypedExpr>),
    /// Entries keyed by name, held in `BTreeMap` (sorted) order.
    Dict(BTreeMap<Symbol, TypedExpr>),
    /// Presumably `base` with the fields in `updates` replaced; confirm.
    RecordUpdate {
        base: Box<TypedExpr>,
        updates: BTreeMap<Symbol, TypedExpr>,
    },
    /// A reference to `name`. `overloads` holds additional types that
    /// `TypedExpr::apply` substitutes alongside the node type
    /// (NOTE(review): presumably candidate overload types; confirm).
    Var {
        name: Symbol,
        overloads: Vec<Type>,
    },
    /// Application of one function to one argument; multi-argument calls are
    /// left-nested `App` spines.
    App(Box<TypedExpr>, Box<TypedExpr>),
    /// Selection of `field` from `expr`.
    Project {
        expr: Box<TypedExpr>,
        field: Symbol,
    },
    /// Single-parameter lambda; curried functions appear as nested `Lam` chains.
    Lam {
        param: Symbol,
        body: Box<TypedExpr>,
    },
    /// `let name = def in body`.
    Let {
        name: Symbol,
        def: Box<TypedExpr>,
        body: Box<TypedExpr>,
    },
    /// A group of bindings followed by `body`; the name suggests the bindings
    /// may refer to each other.
    LetRec {
        bindings: Vec<(Symbol, TypedExpr)>,
        body: Box<TypedExpr>,
    },
    /// `if cond then then_expr else else_expr`.
    Ite {
        cond: Box<TypedExpr>,
        then_expr: Box<TypedExpr>,
        else_expr: Box<TypedExpr>,
    },
    /// `match scrutinee` with `(pattern, body)` arms in source order.
    Match {
        scrutinee: Box<TypedExpr>,
        arms: Vec<(Pattern, TypedExpr)>,
    },
}
659
660/// Compose substitutions `a` after `b`.
661///
662/// If `t.apply(&b)` is “apply `b` first”, then:
663/// `t.apply(&compose_subst(a, b)) == t.apply(&b).apply(&a)`.
664pub fn compose_subst(a: Subst, b: Subst) -> Subst {
665    if subst_is_empty(&a) {
666        return b;
667    }
668    if subst_is_empty(&b) {
669        return a;
670    }
671    let mut res = Subst::new_sync();
672    for (k, v) in b.iter() {
673        res = res.insert(*k, v.apply(&a));
674    }
675    for (k, v) in a.iter() {
676        res = res.insert(*k, v.clone());
677    }
678    res
679}
680
681fn subst_is_empty(s: &Subst) -> bool {
682    s.iter().next().is_none()
683}
684
685fn dedup_preds(preds: Vec<Predicate>) -> Vec<Predicate> {
686    let mut seen = HashSet::new();
687    let mut out = Vec::with_capacity(preds.len());
688    for pred in preds {
689        if seen.insert(pred.clone()) {
690            out.push(pred);
691        }
692    }
693    out
694}
695
696fn is_integral_primitive(typ: &Type) -> bool {
697    matches!(
698        typ.as_ref(),
699        TypeKind::Con(TypeConst {
700            builtin_id: Some(
701                BuiltinTypeId::U8
702                    | BuiltinTypeId::U16
703                    | BuiltinTypeId::U32
704                    | BuiltinTypeId::U64
705                    | BuiltinTypeId::I8
706                    | BuiltinTypeId::I16
707                    | BuiltinTypeId::I32
708                    | BuiltinTypeId::I64
709            ),
710            ..
711        })
712    )
713}
714
715fn finalize_infer_for_public_api(
716    mut preds: Vec<Predicate>,
717    mut typ: Type,
718) -> Result<(Vec<Predicate>, Type), TypeError> {
719    let mut subst = Subst::new_sync();
720    for pred in &preds {
721        if pred.class.as_ref() == "Integral"
722            && let TypeKind::Var(tv) = pred.typ.as_ref()
723        {
724            subst = subst.insert(tv.id, Type::builtin(BuiltinTypeId::I32));
725        }
726    }
727
728    if !subst_is_empty(&subst) {
729        preds = dedup_preds(preds.apply(&subst));
730        typ = typ.apply(&subst);
731    }
732
733    for pred in &preds {
734        if pred.class.as_ref() != "Integral" {
735            continue;
736        }
737        if matches!(pred.typ.as_ref(), TypeKind::Var(_)) || is_integral_primitive(&pred.typ) {
738            continue;
739        }
740        return Err(TypeError::Unification("i32".into(), pred.typ.to_string()));
741    }
742
743    Ok((preds, typ))
744}
745
/// Errors produced by type checking and inference.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum TypeError {
    /// Two types could not be unified; carries both as display strings.
    #[error("types do not unify: {0} vs {1}")]
    Unification(String, String),
    /// Binding the variable would make a type contain itself.
    #[error("occurs check failed for {0} in {1}")]
    Occurs(TypeVarId, String),
    #[error("unknown class {0}")]
    UnknownClass(Symbol),
    #[error("no instance for {0} {1}")]
    NoInstance(Symbol, String),
    #[error("unknown type {0}")]
    UnknownTypeName(Symbol),
    #[error("cannot redefine reserved builtin type `{0}`")]
    ReservedTypeName(Symbol),
    #[error("duplicate value definition `{0}`")]
    DuplicateValue(Symbol),
    #[error("duplicate class definition `{0}`")]
    DuplicateClass(Symbol),
    #[error("class `{class}` must have at least one type parameter (got {got})")]
    InvalidClassArity { class: Symbol, got: usize },
    #[error("duplicate class method `{0}`")]
    DuplicateClassMethod(Symbol),
    #[error("unknown method `{method}` in instance of class `{class}`")]
    UnknownInstanceMethod { class: Symbol, method: Symbol },
    #[error("missing implementation of `{method}` for instance of class `{class}`")]
    MissingInstanceMethod { class: Symbol, method: Symbol },
    #[error(
        "instance method `{method}` requires constraint {class} {typ}, but it is not in the instance context"
    )]
    MissingInstanceConstraint {
        method: Symbol,
        class: Symbol,
        typ: String,
    },
    #[error("unbound variable {0}")]
    UnknownVar(Symbol),
    #[error("ambiguous overload for {0}")]
    AmbiguousOverload(Symbol),
    /// Quantified variables that occur in predicates but not in the type
    /// (see `reject_ambiguous_scheme`); `vars` is sorted.
    #[error("ambiguous type variable(s) {vars:?} in constraints: {constraints}")]
    AmbiguousTypeVars {
        vars: Vec<TypeVarId>,
        constraints: String,
    },
    #[error(
        "kind mismatch for class `{class}`: expected {expected} type argument(s) remaining, got {got} for {typ}"
    )]
    KindMismatch {
        class: Symbol,
        expected: usize,
        got: usize,
        typ: String,
    },
    /// Inferred non-ground predicates not covered by the declared ones
    /// (see `check_non_ground_predicates_declared`).
    #[error("missing type class constraint(s): {constraints}")]
    MissingConstraints { constraints: String },
    #[error("unsupported expression {0}")]
    UnsupportedExpr(&'static str),
    #[error("unknown field `{field}` on {typ}")]
    UnknownField { field: Symbol, typ: String },
    #[error("field `{field}` is not definitely available on {typ}")]
    FieldNotKnown { field: Symbol, typ: String },
    #[error("non-exhaustive match for {typ}: missing {missing:?}")]
    NonExhaustiveMatch { typ: String, missing: Vec<Symbol> },
    /// Another error tagged with a source location. `with_span` never wraps
    /// an already-spanned error, so the first span attached is kept.
    #[error("at {span}: {error}")]
    Spanned { span: Span, error: Box<TypeError> },
    /// Internal inconsistency in the checker; free-form message.
    #[error("internal error: {0}")]
    Internal(String),
    /// Gas budget exhausted; `#[from]` lets `?` convert an `OutOfGas` directly.
    #[error("{0}")]
    OutOfGas(#[from] OutOfGas),
}
815
816fn with_span(span: &Span, err: TypeError) -> TypeError {
817    match err {
818        TypeError::Spanned { .. } => err,
819        other => TypeError::Spanned {
820            span: *span,
821            error: Box::new(other),
822        },
823    }
824}
825
826fn format_constraints_referencing_vars(preds: &[Predicate], vars: &[TypeVarId]) -> String {
827    if vars.is_empty() {
828        return String::new();
829    }
830    let var_set: HashSet<TypeVarId> = vars.iter().copied().collect();
831    let mut parts = Vec::new();
832    for pred in preds {
833        let ftv = pred.ftv();
834        if ftv.iter().any(|v| var_set.contains(v)) {
835            parts.push(format!("{} {}", pred.class, pred.typ));
836        }
837    }
838    if parts.is_empty() {
839        // Fallback: show all constraints if the filtering logic misses something.
840        for pred in preds {
841            parts.push(format!("{} {}", pred.class, pred.typ));
842        }
843    }
844    parts.join(", ")
845}
846
847fn reject_ambiguous_scheme(scheme: &Scheme) -> Result<(), TypeError> {
848    // Only reject *quantified* ambiguous variables. Variables free in the
849    // environment are allowed to appear only in predicates, since they can be
850    // determined by outer context.
851    let quantified: HashSet<TypeVarId> = scheme.vars.iter().map(|v| v.id).collect();
852    if quantified.is_empty() {
853        return Ok(());
854    }
855
856    let typ_ftv = scheme.typ.ftv();
857    let mut vars = HashSet::new();
858    for pred in &scheme.preds {
859        let TypeKind::Var(tv) = pred.typ.as_ref() else {
860            continue;
861        };
862        if quantified.contains(&tv.id) && !typ_ftv.contains(&tv.id) {
863            vars.insert(tv.id);
864        }
865    }
866
867    if vars.is_empty() {
868        return Ok(());
869    }
870    let mut vars: Vec<TypeVarId> = vars.into_iter().collect();
871    vars.sort_unstable();
872    let constraints = format_constraints_referencing_vars(&scheme.preds, &vars);
873    Err(TypeError::AmbiguousTypeVars { vars, constraints })
874}
875
876fn scheme_compatible(existing: &Scheme, declared: &Scheme) -> bool {
877    let s = match unify(&existing.typ, &declared.typ) {
878        Ok(s) => s,
879        Err(_) => return false,
880    };
881
882    let existing_preds = existing.preds.apply(&s);
883    let declared_preds = declared.preds.apply(&s);
884
885    let mut lhs: Vec<(Symbol, String)> = existing_preds
886        .iter()
887        .map(|p| (p.class.clone(), p.typ.to_string()))
888        .collect();
889    let mut rhs: Vec<(Symbol, String)> = declared_preds
890        .iter()
891        .map(|p| (p.class.clone(), p.typ.to_string()))
892        .collect();
893    lhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
894    rhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
895    lhs == rhs
896}
897
/// State carried by the unifier. NOTE(review): its methods are outside this
/// view; the notes on `gas` and the depth fields are based on their names and
/// types only. Confirm against the impl.
#[derive(Debug)]
struct Unifier<'g> {
    // `subs[id] = Some(t)` means type variable `id` has been bound to `t`.
    //
    // This is intentionally a dense `Vec` rather than a `HashMap`: inference
    // generates `TypeVarId`s from a monotonic counter, so the common case is
    // “small id space, lots of lookups”. This makes the cost model obvious:
    // you pay O(max_id) space, and you get O(1) binds/queries.
    subs: Vec<Option<Type>>,
    // Optional gas meter borrowed mutably for `'g`; presumably charged as
    // work is done. TODO confirm.
    gas: Option<&'g mut GasMeter>,
    // Optional cap on inference depth; presumably sourced from
    // `TypeSystemLimits::max_infer_depth`, with `None` meaning no cap.
    max_infer_depth: Option<usize>,
    // Current inference depth, presumably checked against `max_infer_depth`.
    infer_depth: usize,
}
911
/// Resource limits applied by the type checker.
#[derive(Clone, Copy, Debug)]
pub struct TypeSystemLimits {
    /// Maximum inference depth, or `None` for no limit.
    pub max_infer_depth: Option<usize>,
}

impl TypeSystemLimits {
    /// No limits at all.
    pub fn unlimited() -> Self {
        TypeSystemLimits {
            max_infer_depth: None,
        }
    }

    /// Conservative limits (inference depth capped at 4096). These are also
    /// what `Default` returns.
    pub fn safe_defaults() -> Self {
        const SAFE_MAX_INFER_DEPTH: usize = 4096;
        TypeSystemLimits {
            max_infer_depth: Some(SAFE_MAX_INFER_DEPTH),
        }
    }
}

impl Default for TypeSystemLimits {
    /// Same as [`TypeSystemLimits::safe_defaults`].
    fn default() -> Self {
        TypeSystemLimits::safe_defaults()
    }
}
936
937fn superclass_closure(class_env: &ClassEnv, given: &[Predicate]) -> Vec<Predicate> {
938    let mut closure: Vec<Predicate> = given.to_vec();
939    let mut i = 0;
940    while i < closure.len() {
941        let p = closure[i].clone();
942        for sup in class_env.supers_of(&p.class) {
943            closure.push(Predicate::new(sup, p.typ.clone()));
944        }
945        i += 1;
946    }
947    closure
948}
949
950fn check_non_ground_predicates_declared(
951    class_env: &ClassEnv,
952    declared: &[Predicate],
953    inferred: &[Predicate],
954) -> Result<(), TypeError> {
955    // Compare by a stable, user-facing rendering (`Default a`, `Foldable t`, ...),
956    // rather than `TypeVarId`, so signature variables that only appear in
957    // predicates (and thus aren't related by unification) still match up.
958    let closure = superclass_closure(class_env, declared);
959    let closure_keys: HashSet<String> = closure
960        .iter()
961        .map(|p| format!("{} {}", p.class, p.typ))
962        .collect();
963    let mut missing = Vec::new();
964    for pred in inferred {
965        if pred.typ.ftv().is_empty() {
966            continue;
967        }
968        let key = format!("{} {}", pred.class, pred.typ);
969        if !closure_keys.contains(&key) {
970            missing.push(key);
971        }
972    }
973
974    missing.sort();
975    missing.dedup();
976    if missing.is_empty() {
977        return Ok(());
978    }
979    Err(TypeError::MissingConstraints {
980        constraints: missing.join(", "),
981    })
982}
983
984fn type_term_remaining_arity(ty: &Type) -> Option<usize> {
985    match ty.as_ref() {
986        TypeKind::Var(_) => None,
987        TypeKind::Con(tc) => Some(tc.arity),
988        TypeKind::App(l, _) => {
989            let a = type_term_remaining_arity(l)?;
990            Some(a.saturating_sub(1))
991        }
992        TypeKind::Fun(..) | TypeKind::Tuple(..) | TypeKind::Record(..) => Some(0),
993    }
994}
995
996fn max_head_app_arity_for_var(ty: &Type, var_id: TypeVarId) -> usize {
997    let mut max_arity = 0usize;
998    let mut stack: Vec<&Type> = vec![ty];
999    while let Some(t) = stack.pop() {
1000        match t.as_ref() {
1001            TypeKind::Var(_) | TypeKind::Con(_) => {}
1002            TypeKind::App(l, r) => {
1003                // Record the full application depth at this node.
1004                let mut head = t;
1005                let mut args = 0usize;
1006                while let TypeKind::App(left, _) = head.as_ref() {
1007                    args += 1;
1008                    head = left;
1009                }
1010                if let TypeKind::Var(tv) = head.as_ref()
1011                    && tv.id == var_id
1012                {
1013                    max_arity = max_arity.max(args);
1014                }
1015                stack.push(l);
1016                stack.push(r);
1017            }
1018            TypeKind::Fun(a, b) => {
1019                stack.push(a);
1020                stack.push(b);
1021            }
1022            TypeKind::Tuple(ts) => {
1023                for t in ts {
1024                    stack.push(t);
1025                }
1026            }
1027            TypeKind::Record(fields) => {
1028                for (_, t) in fields {
1029                    stack.push(t);
1030                }
1031            }
1032        }
1033    }
1034    max_arity
1035}
1036
1037impl<'g> Unifier<'g> {
1038    fn new(max_infer_depth: Option<usize>) -> Self {
1039        Self {
1040            subs: Vec::new(),
1041            gas: None,
1042            max_infer_depth,
1043            infer_depth: 0,
1044        }
1045    }
1046
1047    fn with_gas(gas: &'g mut GasMeter, max_infer_depth: Option<usize>) -> Self {
1048        Self {
1049            subs: Vec::new(),
1050            gas: Some(gas),
1051            max_infer_depth,
1052            infer_depth: 0,
1053        }
1054    }
1055
1056    fn with_infer_depth<T>(
1057        &mut self,
1058        span: Span,
1059        f: impl FnOnce(&mut Self) -> Result<T, TypeError>,
1060    ) -> Result<T, TypeError> {
1061        if let Some(max) = self.max_infer_depth
1062            && self.infer_depth >= max
1063        {
1064            return Err(TypeError::Spanned {
1065                span,
1066                error: Box::new(TypeError::Internal(format!(
1067                    "maximum inference depth exceeded (max {max})"
1068                ))),
1069            });
1070        }
1071        self.infer_depth += 1;
1072        let res = f(self);
1073        self.infer_depth = self.infer_depth.saturating_sub(1);
1074        res
1075    }
1076
1077    fn charge_infer_node(&mut self) -> Result<(), TypeError> {
1078        let Some(gas) = self.gas.as_mut() else {
1079            return Ok(());
1080        };
1081        let cost = gas.costs.infer_node;
1082        gas.charge(cost)?;
1083        Ok(())
1084    }
1085
1086    fn charge_unify_step(&mut self) -> Result<(), TypeError> {
1087        let Some(gas) = self.gas.as_mut() else {
1088            return Ok(());
1089        };
1090        let cost = gas.costs.unify_step;
1091        gas.charge(cost)?;
1092        Ok(())
1093    }
1094
1095    fn bind_var(&mut self, id: TypeVarId, ty: Type) {
1096        if id >= self.subs.len() {
1097            self.subs.resize(id + 1, None);
1098        }
1099        self.subs[id] = Some(ty);
1100    }
1101
1102    fn prune(&mut self, ty: &Type) -> Type {
1103        match ty.as_ref() {
1104            TypeKind::Var(tv) => {
1105                let bound = self.subs.get(tv.id).and_then(|t| t.clone());
1106                match bound {
1107                    Some(bound) => {
1108                        let pruned = self.prune(&bound);
1109                        self.bind_var(tv.id, pruned.clone());
1110                        pruned
1111                    }
1112                    None => ty.clone(),
1113                }
1114            }
1115            TypeKind::Con(_) => ty.clone(),
1116            TypeKind::App(l, r) => {
1117                let l = self.prune(l);
1118                let r = self.prune(r);
1119                Type::app(l, r)
1120            }
1121            TypeKind::Fun(a, b) => {
1122                let a = self.prune(a);
1123                let b = self.prune(b);
1124                Type::fun(a, b)
1125            }
1126            TypeKind::Tuple(ts) => {
1127                Type::new(TypeKind::Tuple(ts.iter().map(|t| self.prune(t)).collect()))
1128            }
1129            TypeKind::Record(fields) => Type::new(TypeKind::Record(
1130                fields
1131                    .iter()
1132                    .map(|(name, ty)| (name.clone(), self.prune(ty)))
1133                    .collect(),
1134            )),
1135        }
1136    }
1137
1138    fn apply_type(&mut self, ty: &Type) -> Type {
1139        self.prune(ty)
1140    }
1141
1142    fn occurs(&mut self, id: TypeVarId, ty: &Type) -> bool {
1143        match self.prune(ty).as_ref() {
1144            TypeKind::Var(tv) => tv.id == id,
1145            TypeKind::Con(_) => false,
1146            TypeKind::App(l, r) => self.occurs(id, l) || self.occurs(id, r),
1147            TypeKind::Fun(a, b) => self.occurs(id, a) || self.occurs(id, b),
1148            TypeKind::Tuple(ts) => ts.iter().any(|t| self.occurs(id, t)),
1149            TypeKind::Record(fields) => fields.iter().any(|(_, ty)| self.occurs(id, ty)),
1150        }
1151    }
1152
1153    fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
1154        self.charge_unify_step()?;
1155        let t1 = self.prune(t1);
1156        let t2 = self.prune(t2);
1157        match (t1.as_ref(), t2.as_ref()) {
1158            (TypeKind::Var(a), TypeKind::Var(b)) if a.id == b.id => Ok(()),
1159            (TypeKind::Var(tv), other) | (other, TypeKind::Var(tv)) => {
1160                if self.occurs(tv.id, &Type::new(other.clone())) {
1161                    Err(TypeError::Occurs(
1162                        tv.id,
1163                        Type::new(other.clone()).to_string(),
1164                    ))
1165                } else {
1166                    self.bind_var(tv.id, Type::new(other.clone()));
1167                    Ok(())
1168                }
1169            }
1170            (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(()),
1171            (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1172                self.unify(l1, l2)?;
1173                self.unify(r1, r2)
1174            }
1175            (TypeKind::Fun(a1, b1), TypeKind::Fun(a2, b2)) => {
1176                self.unify(a1, a2)?;
1177                self.unify(b1, b2)
1178            }
1179            (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1180                if ts1.len() != ts2.len() {
1181                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1182                }
1183                for (a, b) in ts1.iter().zip(ts2.iter()) {
1184                    self.unify(a, b)?;
1185                }
1186                Ok(())
1187            }
1188            (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1189                if f1.len() != f2.len() {
1190                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1191                }
1192                for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1193                    if n1 != n2 {
1194                        return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1195                    }
1196                    self.unify(t1, t2)?;
1197                }
1198                Ok(())
1199            }
1200            (TypeKind::Record(fields), TypeKind::App(head, arg))
1201            | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1202                TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1203                    let elem_ty = record_elem_type_unifier(fields, self)?;
1204                    self.unify(arg, &elem_ty)
1205                }
1206                TypeKind::Var(tv) => {
1207                    self.unify(
1208                        &Type::new(TypeKind::Var(tv.clone())),
1209                        &Type::builtin(BuiltinTypeId::Dict),
1210                    )?;
1211                    let elem_ty = record_elem_type_unifier(fields, self)?;
1212                    self.unify(arg, &elem_ty)
1213                }
1214                _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1215            },
1216            _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1217        }
1218    }
1219
1220    fn into_subst(mut self) -> Subst {
1221        let mut out = Subst::new_sync();
1222        for id in 0..self.subs.len() {
1223            if let Some(ty) = self.subs[id].clone() {
1224                let pruned = self.prune(&ty);
1225                out = out.insert(id, pruned);
1226            }
1227        }
1228        out
1229    }
1230}
1231
1232fn record_elem_type_unifier(
1233    fields: &[(Symbol, Type)],
1234    unifier: &mut Unifier<'_>,
1235) -> Result<Type, TypeError> {
1236    let mut iter = fields.iter();
1237    let first = match iter.next() {
1238        Some((_, ty)) => ty.clone(),
1239        None => return Err(TypeError::UnsupportedExpr("empty record")),
1240    };
1241    for (_, ty) in iter {
1242        unifier.unify(&first, ty)?;
1243    }
1244    Ok(unifier.apply_type(&first))
1245}
1246
1247fn bind(tv: &TypeVar, t: &Type) -> Result<Subst, TypeError> {
1248    if let TypeKind::Var(var) = t.as_ref()
1249        && var.id == tv.id
1250    {
1251        return Ok(Subst::new_sync());
1252    }
1253    if t.ftv().contains(&tv.id) {
1254        Err(TypeError::Occurs(tv.id, t.to_string()))
1255    } else {
1256        Ok(Subst::new_sync().insert(tv.id, t.clone()))
1257    }
1258}
1259
1260fn record_elem_type(fields: &[(Symbol, Type)]) -> Result<(Subst, Type), TypeError> {
1261    let mut iter = fields.iter();
1262    let first = match iter.next() {
1263        Some((_, ty)) => ty.clone(),
1264        None => return Err(TypeError::UnsupportedExpr("empty record")),
1265    };
1266    let mut subst = Subst::new_sync();
1267    let mut current = first;
1268    for (_, ty) in iter {
1269        let s_next = unify(&current.apply(&subst), &ty.apply(&subst))?;
1270        subst = compose_subst(s_next, subst);
1271        current = current.apply(&subst);
1272    }
1273    Ok((subst.clone(), current.apply(&subst)))
1274}
1275
1276/// Compute a most-general unifier for two types.
1277///
1278/// This is the “pure” unifier: it returns an explicit substitution map and is
1279/// easy to read/compose in isolation. The type inference engine uses `Unifier`
1280/// directly to avoid allocating and composing persistent maps at every
1281/// unification step.
1282pub fn unify(t1: &Type, t2: &Type) -> Result<Subst, TypeError> {
1283    match (t1.as_ref(), t2.as_ref()) {
1284        (TypeKind::Fun(l1, r1), TypeKind::Fun(l2, r2)) => {
1285            let s1 = unify(l1, l2)?;
1286            let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1287            Ok(compose_subst(s2, s1))
1288        }
1289        (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1290            if f1.len() != f2.len() {
1291                return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1292            }
1293            let mut subst = Subst::new_sync();
1294            for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1295                if n1 != n2 {
1296                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1297                }
1298                let s_next = unify(&t1.apply(&subst), &t2.apply(&subst))?;
1299                subst = compose_subst(s_next, subst);
1300            }
1301            Ok(subst)
1302        }
1303        (TypeKind::Record(fields), TypeKind::App(head, arg))
1304        | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1305            TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1306                let (s_fields, elem_ty) = record_elem_type(fields)?;
1307                let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1308                Ok(compose_subst(s_arg, s_fields))
1309            }
1310            TypeKind::Var(tv) => {
1311                let s_head = bind(tv, &Type::builtin(BuiltinTypeId::Dict))?;
1312                let arg = arg.apply(&s_head);
1313                let (s_fields, elem_ty) = record_elem_type(fields)?;
1314                let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1315                Ok(compose_subst(s_arg, compose_subst(s_fields, s_head)))
1316            }
1317            _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1318        },
1319        (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1320            let s1 = unify(l1, l2)?;
1321            let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1322            Ok(compose_subst(s2, s1))
1323        }
1324        (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1325            if ts1.len() != ts2.len() {
1326                return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1327            }
1328            let mut s = Subst::new_sync();
1329            for (a, b) in ts1.iter().zip(ts2.iter()) {
1330                let s_next = unify(&a.apply(&s), &b.apply(&s))?;
1331                s = compose_subst(s_next, s);
1332            }
1333            Ok(s)
1334        }
1335        (TypeKind::Var(tv), t) | (t, TypeKind::Var(tv)) => bind(tv, &Type::new(t.clone())),
1336        (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(Subst::new_sync()),
1337        _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1338    }
1339}
1340
1341#[derive(Default, Debug, Clone)]
1342pub struct TypeEnv {
1343    pub values: HashTrieMapSync<Symbol, Vec<Scheme>>,
1344}
1345
1346impl TypeEnv {
1347    pub fn new() -> Self {
1348        Self {
1349            values: HashTrieMapSync::new_sync(),
1350        }
1351    }
1352
1353    pub fn extend(&mut self, name: Symbol, scheme: Scheme) {
1354        self.values = self.values.insert(name, vec![scheme]);
1355    }
1356
1357    pub fn extend_overload(&mut self, name: Symbol, scheme: Scheme) {
1358        let mut schemes = self.values.get(&name).cloned().unwrap_or_default();
1359        schemes.push(scheme);
1360        self.values = self.values.insert(name, schemes);
1361    }
1362
1363    pub fn remove(&mut self, name: &Symbol) {
1364        self.values = self.values.remove(name);
1365    }
1366
1367    pub fn lookup(&self, name: &Symbol) -> Option<&[Scheme]> {
1368        self.values.get(name).map(|schemes| schemes.as_slice())
1369    }
1370}
1371
1372impl Types for TypeEnv {
1373    fn apply(&self, s: &Subst) -> Self {
1374        let mut values = HashTrieMapSync::new_sync();
1375        for (k, v) in self.values.iter() {
1376            let updated = v
1377                .iter()
1378                .map(|scheme| {
1379                    // Most schemes in environments are monomorphic. Don't walk
1380                    // and rebuild trees unless we actually have work to do.
1381                    if scheme.vars.is_empty() && !subst_is_empty(s) {
1382                        scheme.apply(s)
1383                    } else {
1384                        scheme.clone()
1385                    }
1386                })
1387                .collect();
1388            values = values.insert(k.clone(), updated);
1389        }
1390        TypeEnv { values }
1391    }
1392
1393    fn ftv(&self) -> HashSet<TypeVarId> {
1394        self.values
1395            .iter()
1396            .flat_map(|(_, schemes)| schemes.iter().flat_map(Types::ftv))
1397            .collect()
1398    }
1399}
1400
/// Pairs an ADT name (`adt`) with the name of one of its variants (`variant`).
///
/// NOTE(review): not used in this part of the file. Confirm its exact
/// meaning at the use sites.
#[derive(Clone, Debug)]
struct KnownVariant {
    adt: Symbol,
    variant: Symbol,
}

// Map from a `Symbol` to the variant it is known to hold. NOTE(review): the
// key is presumably a binder/variable name; verify against the callers.
type KnownVariants = HashMap<Symbol, KnownVariant>;
1408
1409#[derive(Default, Debug, Clone)]
1410pub struct TypeVarSupply {
1411    counter: TypeVarId,
1412}
1413
1414impl TypeVarSupply {
1415    pub fn new() -> Self {
1416        Self { counter: 0 }
1417    }
1418
1419    pub fn fresh(&mut self, name_hint: impl Into<Option<Symbol>>) -> TypeVar {
1420        let tv = TypeVar::new(self.counter, name_hint.into());
1421        self.counter += 1;
1422        tv
1423    }
1424}
1425
1426fn apply_scheme_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> Scheme {
1427    let preds = scheme
1428        .preds
1429        .iter()
1430        .map(|pred| Predicate::new(pred.class.clone(), unifier.apply_type(&pred.typ)))
1431        .collect();
1432    let typ = unifier.apply_type(&scheme.typ);
1433    Scheme::new(scheme.vars.clone(), preds, typ)
1434}
1435
1436fn scheme_ftv_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> HashSet<TypeVarId> {
1437    let mut ftv = unifier.apply_type(&scheme.typ).ftv();
1438    for pred in &scheme.preds {
1439        ftv.extend(unifier.apply_type(&pred.typ).ftv());
1440    }
1441    for var in &scheme.vars {
1442        ftv.remove(&var.id);
1443    }
1444    ftv
1445}
1446
1447fn env_ftv_with_unifier(env: &TypeEnv, unifier: &mut Unifier<'_>) -> HashSet<TypeVarId> {
1448    let mut out = HashSet::new();
1449    for (_name, schemes) in env.values.iter() {
1450        for scheme in schemes {
1451            out.extend(scheme_ftv_with_unifier(scheme, unifier));
1452        }
1453    }
1454    out
1455}
1456
1457fn generalize_with_unifier(
1458    env: &TypeEnv,
1459    preds: Vec<Predicate>,
1460    typ: Type,
1461    unifier: &mut Unifier<'_>,
1462) -> Scheme {
1463    // This is `generalize`, but operating in the “imperative unifier world”.
1464    // It avoids constructing intermediate `Subst` maps while inference is
1465    // still mutating type variables.
1466    let preds: Vec<Predicate> = preds
1467        .into_iter()
1468        .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
1469        .collect();
1470    let typ = unifier.apply_type(&typ);
1471    let mut vars: Vec<TypeVar> = typ
1472        .ftv()
1473        .union(&preds.ftv())
1474        .copied()
1475        .collect::<HashSet<_>>()
1476        .difference(&env_ftv_with_unifier(env, unifier))
1477        .cloned()
1478        .map(|id| TypeVar::new(id, None))
1479        .collect();
1480    vars.sort_by_key(|v| v.id);
1481    Scheme::new(vars, preds, typ)
1482}
1483
1484fn monomorphic_scheme_with_unifier(
1485    preds: Vec<Predicate>,
1486    typ: Type,
1487    unifier: &mut Unifier<'_>,
1488) -> Scheme {
1489    let preds = dedup_preds(
1490        preds
1491            .into_iter()
1492            .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
1493            .collect(),
1494    );
1495    let typ = unifier.apply_type(&typ);
1496    Scheme::new(vec![], preds, typ)
1497}
1498
1499fn is_integral_literal_expr(expr: &Expr) -> bool {
1500    matches!(expr, Expr::Int(..) | Expr::Uint(..))
1501}
1502
1503/// Turn a monotype `typ` (plus constraints `preds`) into a polymorphic `Scheme`
1504/// by quantifying over the type variables not free in `env`.
1505pub fn generalize(env: &TypeEnv, preds: Vec<Predicate>, typ: Type) -> Scheme {
1506    let mut vars: Vec<TypeVar> = typ
1507        .ftv()
1508        .union(&preds.ftv())
1509        .copied()
1510        .collect::<HashSet<_>>()
1511        .difference(&env.ftv())
1512        .cloned()
1513        .map(|id| TypeVar::new(id, None))
1514        .collect();
1515    vars.sort_by_key(|v| v.id);
1516    Scheme::new(vars, preds, typ)
1517}
1518
1519pub fn instantiate(scheme: &Scheme, supply: &mut TypeVarSupply) -> (Vec<Predicate>, Type) {
1520    // Instantiate replaces all quantified variables with fresh unification
1521    // variables, preserving the original name as a debugging hint.
1522    let mut subst = Subst::new_sync();
1523    for v in &scheme.vars {
1524        subst = subst.insert(v.id, Type::var(supply.fresh(v.name.clone())));
1525    }
1526    (scheme.preds.apply(&subst), scheme.typ.apply(&subst))
1527}
1528
/// A named type parameter for an ADT (e.g. `a` in `List a`).
#[derive(Clone, Debug)]
pub struct AdtParam {
    /// Parameter name as given to `AdtDecl::new`.
    pub name: Symbol,
    /// Type variable that stands for this parameter (see `AdtDecl::param_type`
    /// and `AdtDecl::result_type`).
    pub var: TypeVar,
}

/// A single ADT variant with zero or more constructor arguments.
#[derive(Clone, Debug)]
pub struct AdtVariant {
    /// Constructor name; also the name of its constructor scheme.
    pub name: Symbol,
    /// Argument types in order: `C :: args[0] -> args[1] -> ... -> T params`.
    pub args: Vec<Type>,
}

/// A type declaration for an algebraic data type.
///
/// This only describes the *type* surface (params + variants). It does not
/// introduce any runtime values by itself. Runtime values are created by
/// injecting constructor schemes into the environment (see `inject_adt`).
#[derive(Clone, Debug)]
pub struct AdtDecl {
    /// Type constructor name (the head of `result_type`).
    pub name: Symbol,
    /// Type parameters in declaration order.
    pub params: Vec<AdtParam>,
    /// Variants in the order they were added.
    pub variants: Vec<AdtVariant>,
}
1554
1555impl AdtDecl {
1556    pub fn new(name: &Symbol, param_names: &[Symbol], supply: &mut TypeVarSupply) -> Self {
1557        let params = param_names
1558            .iter()
1559            .map(|p| AdtParam {
1560                name: p.clone(),
1561                var: supply.fresh(Some(p.clone())),
1562            })
1563            .collect();
1564        Self {
1565            name: name.clone(),
1566            params,
1567            variants: Vec::new(),
1568        }
1569    }
1570
1571    pub fn param_type(&self, name: &Symbol) -> Option<Type> {
1572        self.params
1573            .iter()
1574            .find(|p| &p.name == name)
1575            .map(|p| Type::var(p.var.clone()))
1576    }
1577
1578    pub fn add_variant(&mut self, name: Symbol, args: Vec<Type>) {
1579        self.variants.push(AdtVariant { name, args });
1580    }
1581
1582    pub fn result_type(&self) -> Type {
1583        let mut ty = Type::con(&self.name, self.params.len());
1584        for param in &self.params {
1585            ty = Type::app(ty, Type::var(param.var.clone()));
1586        }
1587        ty
1588    }
1589
1590    /// Build constructor schemes of the form:
1591    /// `C :: a1 -> a2 -> ... -> T params`.
1592    pub fn constructor_schemes(&self) -> Vec<(Symbol, Scheme)> {
1593        let result_ty = self.result_type();
1594        let vars: Vec<TypeVar> = self.params.iter().map(|p| p.var.clone()).collect();
1595        let mut out = Vec::new();
1596        for variant in &self.variants {
1597            let mut typ = result_ty.clone();
1598            for arg in variant.args.iter().rev() {
1599                typ = Type::fun(arg.clone(), typ);
1600            }
1601            out.push((variant.name.clone(), Scheme::new(vars.clone(), vec![], typ)));
1602        }
1603        out
1604    }
1605}
1606
1607#[derive(Clone, Debug)]
1608pub struct Class {
1609    pub supers: Vec<Symbol>,
1610}
1611
1612impl Class {
1613    pub fn new(supers: Vec<Symbol>) -> Self {
1614        Self { supers }
1615    }
1616}
1617
1618#[derive(Clone, Debug)]
1619pub struct Instance {
1620    pub context: Vec<Predicate>,
1621    pub head: Predicate,
1622}
1623
1624impl Instance {
1625    pub fn new(context: Vec<Predicate>, head: Predicate) -> Self {
1626        Self { context, head }
1627    }
1628}
1629
1630#[derive(Default, Debug, Clone)]
1631pub struct ClassEnv {
1632    pub classes: HashMap<Symbol, Class>,
1633    pub instances: HashMap<Symbol, Vec<Instance>>,
1634}
1635
1636impl ClassEnv {
1637    pub fn new() -> Self {
1638        Self {
1639            classes: HashMap::new(),
1640            instances: HashMap::new(),
1641        }
1642    }
1643
1644    pub fn add_class(&mut self, name: Symbol, supers: Vec<Symbol>) {
1645        self.classes.insert(name, Class::new(supers));
1646    }
1647
1648    pub fn add_instance(&mut self, class: Symbol, inst: Instance) {
1649        self.instances.entry(class).or_default().push(inst);
1650    }
1651
1652    pub fn supers_of(&self, class: &Symbol) -> Vec<Symbol> {
1653        self.classes
1654            .get(class)
1655            .map(|c| c.supers.clone())
1656            .unwrap_or_default()
1657    }
1658}
1659
1660pub fn entails(
1661    class_env: &ClassEnv,
1662    given: &[Predicate],
1663    pred: &Predicate,
1664) -> Result<bool, TypeError> {
1665    // Expand given with superclasses.
1666    let mut closure: Vec<Predicate> = given.to_vec();
1667    let mut i = 0;
1668    while i < closure.len() {
1669        let p = closure[i].clone();
1670        for sup in class_env.supers_of(&p.class) {
1671            closure.push(Predicate::new(sup, p.typ.clone()));
1672        }
1673        i += 1;
1674    }
1675
1676    if closure
1677        .iter()
1678        .any(|p| p.class == pred.class && p.typ == pred.typ)
1679    {
1680        return Ok(true);
1681    }
1682
1683    if !class_env.classes.contains_key(&pred.class) {
1684        return Err(TypeError::UnknownClass(pred.class.clone()));
1685    }
1686
1687    if let Some(instances) = class_env.instances.get(&pred.class) {
1688        for inst in instances {
1689            if let Ok(s) = unify(&inst.head.typ, &pred.typ) {
1690                let ctx = inst.context.apply(&s);
1691                if ctx
1692                    .iter()
1693                    .all(|c| entails(class_env, &closure, c).unwrap_or(false))
1694                {
1695                    return Ok(true);
1696                }
1697            }
1698        }
1699    }
1700    Ok(false)
1701}
1702
/// Top-level type-checking state.
///
/// Holds the value environment, class environment, ADT and class
/// declarations, and the fresh type-variable supply.
#[derive(Default, Debug, Clone)]
pub struct TypeSystem {
    /// Type schemes for values in scope; a name may carry several overloads.
    pub env: TypeEnv,
    /// Class hierarchy and instances, as consulted by predicate entailment.
    pub classes: ClassEnv,
    /// Declared algebraic data types (presumably keyed by type name; confirm).
    pub adts: HashMap<Symbol, AdtDecl>,
    /// Per-class metadata from class declarations, keyed by class name.
    pub class_info: HashMap<Symbol, ClassInfo>,
    /// Class method metadata. NOTE(review): presumably keyed by method name;
    /// confirm at the insertion site.
    pub class_methods: HashMap<Symbol, ClassMethodInfo>,
    /// Names introduced by `declare fn` (forward declarations).
    ///
    /// These are placeholders in the type environment and must not block a later
    /// real definition (e.g. `fn foo = ...` or host/CLI injection).
    pub declared_values: HashSet<Symbol>,
    /// Source of fresh type variables for this type system.
    pub supply: TypeVarSupply,
    /// Inference limits; change them with `set_limits`.
    limits: TypeSystemLimits,
}
1718
/// Semantic information about a type class declaration, derived from Rex source.
///
/// Design notes (WARM):
/// - We keep this explicit and data-oriented: it makes review easy and keeps costs visible.
/// - Rex represents multi-parameter classes by encoding the parameters as a tuple in the
///   single `Predicate.typ` slot. For a unary class `C a` the predicate is `C a`. For a
///   binary class `C t a` the predicate is `C (t, a)`, etc.
/// - This keeps the runtime/type-inference machinery simple: instance matching is still
///   “unify the predicate types”, and no separate arity tracking is needed.
#[derive(Clone, Debug)]
pub struct ClassInfo {
    /// Class name.
    pub name: Symbol,
    /// Class parameter names (presumably in declaration order; confirm).
    pub params: Vec<Symbol>,
    /// Names of the declared superclasses.
    pub supers: Vec<Symbol>,
    /// Method signatures (presumably keyed by method name; confirm).
    pub methods: BTreeMap<Symbol, Scheme>,
}

/// Metadata for one class method.
#[derive(Clone, Debug)]
pub struct ClassMethodInfo {
    /// The class that declares this method.
    pub class: Symbol,
    /// The method's type scheme (presumably including the class predicate; confirm).
    pub scheme: Scheme,
}

/// An instance declaration after its head and context have been resolved.
///
/// Returned by `inject_instance_decl` (`inject_decl` discards it).
#[derive(Clone, Debug)]
pub struct PreparedInstanceDecl {
    /// Source span of the instance declaration.
    pub span: Span,
    /// Class being instantiated.
    pub class: Symbol,
    /// Instance head type.
    pub head: Type,
    /// Context predicates required by the instance.
    pub context: Vec<Predicate>,
}
1749
1750impl TypeSystem {
1751    pub fn new() -> Self {
1752        Self {
1753            env: TypeEnv::new(),
1754            classes: ClassEnv::new(),
1755            adts: HashMap::new(),
1756            class_info: HashMap::new(),
1757            class_methods: HashMap::new(),
1758            declared_values: HashSet::new(),
1759            supply: TypeVarSupply::new(),
1760            limits: TypeSystemLimits::default(),
1761        }
1762    }
1763
1764    pub fn fresh_type_var(&mut self, name: Option<Symbol>) -> TypeVar {
1765        self.supply.fresh(name)
1766    }
1767
1768    pub fn set_limits(&mut self, limits: TypeSystemLimits) {
1769        self.limits = limits;
1770    }
1771
1772    pub fn with_prelude() -> Result<Self, TypeError> {
1773        let mut ts = TypeSystem::new();
1774        prelude::build_prelude(&mut ts)?;
1775        Ok(ts)
1776    }
1777
1778    pub fn inject_decl(&mut self, decl: &Decl) -> Result<(), TypeError> {
1779        match decl {
1780            Decl::Type(ty) => self.inject_type_decl(ty),
1781            Decl::Class(class_decl) => self.inject_class_decl(class_decl),
1782            Decl::Instance(inst_decl) => {
1783                let _ = self.inject_instance_decl(inst_decl)?;
1784                Ok(())
1785            }
1786            Decl::Fn(fd) => self.inject_fn_decls(std::slice::from_ref(fd)),
1787            Decl::DeclareFn(fd) => self.inject_declare_fn_decl(fd),
1788            Decl::Import(..) => Ok(()),
1789        }
1790    }
1791
1792    pub fn inject_decls(&mut self, decls: &[Decl]) -> Result<(), TypeError> {
1793        let mut pending_fns: Vec<FnDecl> = Vec::new();
1794        for decl in decls {
1795            if let Decl::Fn(fd) = decl {
1796                pending_fns.push(fd.clone());
1797                continue;
1798            }
1799
1800            if !pending_fns.is_empty() {
1801                self.inject_fn_decls(&pending_fns)?;
1802                pending_fns.clear();
1803            }
1804
1805            self.inject_decl(decl)?;
1806        }
1807        if !pending_fns.is_empty() {
1808            self.inject_fn_decls(&pending_fns)?;
1809        }
1810        Ok(())
1811    }
1812
1813    pub fn add_value(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1814        let name = sym(name.as_ref());
1815        self.declared_values.remove(&name);
1816        self.env.extend(name, scheme);
1817    }
1818
1819    pub fn add_overload(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1820        let name = sym(name.as_ref());
1821        self.declared_values.remove(&name);
1822        self.env.extend_overload(name, scheme);
1823    }
1824
1825    pub fn inject_class(&mut self, name: impl AsRef<str>, supers: Vec<Symbol>) {
1826        self.classes.add_class(sym(name.as_ref()), supers);
1827    }
1828
1829    pub fn inject_instance(&mut self, class: impl AsRef<str>, inst: Instance) {
1830        self.classes.add_instance(sym(class.as_ref()), inst);
1831    }
1832
1833    pub fn inject_class_decl(&mut self, decl: &ClassDecl) -> Result<(), TypeError> {
1834        let span = decl.span;
1835        (|| {
1836            // Classes are global, and Rex does not support reopening/merging them.
1837            // Allowing that would be a long-term maintenance hazard: it creates
1838            // spooky-action-at-a-distance across modules and makes reviews harder.
1839            if self.class_info.contains_key(&decl.name)
1840                || self.classes.classes.contains_key(&decl.name)
1841            {
1842                return Err(TypeError::DuplicateClass(decl.name.clone()));
1843            }
1844            if decl.params.is_empty() {
1845                return Err(TypeError::InvalidClassArity {
1846                    class: decl.name.clone(),
1847                    got: decl.params.len(),
1848                });
1849            }
1850            let params = decl.params.clone();
1851
1852            // Register the superclass relationships in the class environment.
1853            //
1854            // We only accept `<= C param` style superclasses for now. Anything
1855            // fancier would require storing type-level relationships in `ClassEnv`,
1856            // which Rex does not currently model.
1857            let mut supers = Vec::with_capacity(decl.supers.len());
1858            if !decl.supers.is_empty() && params.len() != 1 {
1859                return Err(TypeError::UnsupportedExpr(
1860                    "multi-parameter classes cannot declare superclasses yet",
1861                ));
1862            }
1863            for sup in &decl.supers {
1864                let mut vars = HashMap::new();
1865                let param = params[0].clone();
1866                let param_tv = self.supply.fresh(Some(param.clone()));
1867                vars.insert(param, param_tv.clone());
1868                let sup_ty = type_from_annotation_expr_vars(
1869                    &self.adts,
1870                    &sup.typ,
1871                    &mut vars,
1872                    &mut self.supply,
1873                )?;
1874                if sup_ty != Type::var(param_tv) {
1875                    return Err(TypeError::UnsupportedExpr(
1876                        "superclass constraints must be of the form `<= C a`",
1877                    ));
1878                }
1879                supers.push(sup.class.to_dotted_symbol());
1880            }
1881
1882            self.classes.add_class(decl.name.clone(), supers.clone());
1883
1884            let mut methods = BTreeMap::new();
1885            for ClassMethodSig { name, typ } in &decl.methods {
1886                if self.env.lookup(name).is_some() || self.class_methods.contains_key(name) {
1887                    return Err(TypeError::DuplicateClassMethod(name.clone()));
1888                }
1889
1890                let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1891                let mut param_tvs: Vec<TypeVar> = Vec::with_capacity(params.len());
1892                for param in &params {
1893                    let tv = self.supply.fresh(Some(param.clone()));
1894                    vars.insert(param.clone(), tv.clone());
1895                    param_tvs.push(tv);
1896                }
1897
1898                let ty =
1899                    type_from_annotation_expr_vars(&self.adts, typ, &mut vars, &mut self.supply)?;
1900
1901                let mut scheme_vars: Vec<TypeVar> = vars.values().cloned().collect();
1902                scheme_vars.sort_by_key(|tv| tv.id);
1903                scheme_vars.dedup_by_key(|tv| tv.id);
1904
1905                let class_pred = Predicate {
1906                    class: decl.name.clone(),
1907                    typ: if param_tvs.len() == 1 {
1908                        Type::var(param_tvs[0].clone())
1909                    } else {
1910                        Type::tuple(param_tvs.into_iter().map(Type::var).collect())
1911                    },
1912                };
1913                let scheme = Scheme::new(scheme_vars, vec![class_pred], ty);
1914
1915                self.env.extend(name.clone(), scheme.clone());
1916                self.class_methods.insert(
1917                    name.clone(),
1918                    ClassMethodInfo {
1919                        class: decl.name.clone(),
1920                        scheme: scheme.clone(),
1921                    },
1922                );
1923                methods.insert(name.clone(), scheme);
1924            }
1925
1926            self.class_info.insert(
1927                decl.name.clone(),
1928                ClassInfo {
1929                    name: decl.name.clone(),
1930                    params,
1931                    supers,
1932                    methods,
1933                },
1934            );
1935            Ok(())
1936        })()
1937        .map_err(|err| with_span(&span, err))
1938    }
1939
1940    pub fn inject_instance_decl(
1941        &mut self,
1942        decl: &InstanceDecl,
1943    ) -> Result<PreparedInstanceDecl, TypeError> {
1944        let span = decl.span;
1945        (|| {
1946            let class = decl.class.clone();
1947            if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
1948                return Err(TypeError::UnknownClass(class));
1949            }
1950
1951            let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1952            let head = type_from_annotation_expr_vars(
1953                &self.adts,
1954                &decl.head,
1955                &mut vars,
1956                &mut self.supply,
1957            )?;
1958            let context = predicates_from_constraints(
1959                &self.adts,
1960                &decl.context,
1961                &mut vars,
1962                &mut self.supply,
1963            )?;
1964
1965            let inst = Instance::new(
1966                context.clone(),
1967                Predicate {
1968                    class: decl.class.clone(),
1969                    typ: head.clone(),
1970                },
1971            );
1972
1973            // Validate method list against the class declaration if present.
1974            if let Some(info) = self.class_info.get(&decl.class) {
1975                for method in &decl.methods {
1976                    if !info.methods.contains_key(&method.name) {
1977                        return Err(TypeError::UnknownInstanceMethod {
1978                            class: decl.class.clone(),
1979                            method: method.name.clone(),
1980                        });
1981                    }
1982                }
1983                for method_name in info.methods.keys() {
1984                    if !decl.methods.iter().any(|m| &m.name == method_name) {
1985                        return Err(TypeError::MissingInstanceMethod {
1986                            class: decl.class.clone(),
1987                            method: method_name.clone(),
1988                        });
1989                    }
1990                }
1991            }
1992
1993            self.classes.add_instance(decl.class.clone(), inst);
1994            Ok(PreparedInstanceDecl {
1995                span,
1996                class: decl.class.clone(),
1997                head,
1998                context,
1999            })
2000        })()
2001        .map_err(|err| with_span(&span, err))
2002    }
2003
2004    pub fn prepare_instance_decl(
2005        &mut self,
2006        decl: &InstanceDecl,
2007    ) -> Result<PreparedInstanceDecl, TypeError> {
2008        let span = decl.span;
2009        (|| {
2010            let class = decl.class.clone();
2011            if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
2012                return Err(TypeError::UnknownClass(class));
2013            }
2014
2015            let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
2016            let head = type_from_annotation_expr_vars(
2017                &self.adts,
2018                &decl.head,
2019                &mut vars,
2020                &mut self.supply,
2021            )?;
2022            let context = predicates_from_constraints(
2023                &self.adts,
2024                &decl.context,
2025                &mut vars,
2026                &mut self.supply,
2027            )?;
2028
2029            // Validate method list against the class declaration if present.
2030            if let Some(info) = self.class_info.get(&decl.class) {
2031                for method in &decl.methods {
2032                    if !info.methods.contains_key(&method.name) {
2033                        return Err(TypeError::UnknownInstanceMethod {
2034                            class: decl.class.clone(),
2035                            method: method.name.clone(),
2036                        });
2037                    }
2038                }
2039                for method_name in info.methods.keys() {
2040                    if !decl.methods.iter().any(|m| &m.name == method_name) {
2041                        return Err(TypeError::MissingInstanceMethod {
2042                            class: decl.class.clone(),
2043                            method: method_name.clone(),
2044                        });
2045                    }
2046                }
2047            }
2048
2049            Ok(PreparedInstanceDecl {
2050                span,
2051                class: decl.class.clone(),
2052                head,
2053                context,
2054            })
2055        })()
2056        .map_err(|err| with_span(&span, err))
2057    }
2058
2059    pub fn inject_fn_decl(&mut self, decl: &FnDecl) -> Result<(), TypeError> {
2060        self.inject_fn_decls(std::slice::from_ref(decl))
2061    }
2062
2063    pub fn inject_fn_decls(&mut self, decls: &[FnDecl]) -> Result<(), TypeError> {
2064        if decls.is_empty() {
2065            return Ok(());
2066        }
2067
2068        let saved_env = self.env.clone();
2069        let saved_declared = self.declared_values.clone();
2070
2071        let result: Result<(), TypeError> = (|| {
2072            #[derive(Clone)]
2073            struct FnInfo {
2074                decl: FnDecl,
2075                expected: Type,
2076                declared_preds: Vec<Predicate>,
2077                scheme: Scheme,
2078                ann_vars: HashMap<Symbol, TypeVar>,
2079            }
2080
2081            let mut infos: Vec<FnInfo> = Vec::with_capacity(decls.len());
2082            let mut seen_names = HashSet::new();
2083
2084            for decl in decls {
2085                let span = decl.span;
2086                let info = (|| {
2087                    let name = &decl.name.name;
2088                    if !seen_names.insert(name.clone()) {
2089                        return Err(TypeError::DuplicateValue(name.clone()));
2090                    }
2091
2092                    if self.env.lookup(name).is_some() {
2093                        if self.declared_values.remove(name) {
2094                            // A forward declaration should not block the real definition.
2095                            self.env.remove(name);
2096                        } else {
2097                            return Err(TypeError::DuplicateValue(name.clone()));
2098                        }
2099                    }
2100
2101                    let mut sig = decl.ret.clone();
2102                    for (_, ann) in decl.params.iter().rev() {
2103                        let span = Span::from_begin_end(ann.span().begin, sig.span().end);
2104                        sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
2105                    }
2106
2107                    let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
2108                    let expected = type_from_annotation_expr_vars(
2109                        &self.adts,
2110                        &sig,
2111                        &mut ann_vars,
2112                        &mut self.supply,
2113                    )?;
2114                    let declared_preds = predicates_from_constraints(
2115                        &self.adts,
2116                        &decl.constraints,
2117                        &mut ann_vars,
2118                        &mut self.supply,
2119                    )?;
2120
2121                    // Validate that declared constraints are well-formed.
2122                    let var_arities: HashMap<TypeVarId, usize> = ann_vars
2123                        .values()
2124                        .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
2125                        .collect();
2126                    for pred in &declared_preds {
2127                        let _ = entails(&self.classes, &[], pred)?;
2128                        let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
2129                        else {
2130                            continue;
2131                        };
2132                        let args: Vec<Type> = if expected_arities.len() == 1 {
2133                            vec![pred.typ.clone()]
2134                        } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2135                            if parts.len() != expected_arities.len() {
2136                                continue;
2137                            }
2138                            parts.clone()
2139                        } else {
2140                            continue;
2141                        };
2142
2143                        for (arg, expected_arity) in
2144                            args.iter().zip(expected_arities.iter().copied())
2145                        {
2146                            let got =
2147                                type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
2148                                    TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
2149                                    _ => None,
2150                                });
2151                            let Some(got) = got else {
2152                                continue;
2153                            };
2154                            if got != expected_arity {
2155                                return Err(TypeError::KindMismatch {
2156                                    class: pred.class.clone(),
2157                                    expected: expected_arity,
2158                                    got,
2159                                    typ: arg.to_string(),
2160                                });
2161                            }
2162                        }
2163                    }
2164
2165                    let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
2166                    vars.sort_by_key(|v| v.id);
2167                    let scheme = Scheme::new(vars, declared_preds.clone(), expected.clone());
2168                    reject_ambiguous_scheme(&scheme)?;
2169
2170                    Ok(FnInfo {
2171                        decl: decl.clone(),
2172                        expected,
2173                        declared_preds,
2174                        scheme,
2175                        ann_vars,
2176                    })
2177                })();
2178
2179                infos.push(info.map_err(|err| with_span(&span, err))?);
2180            }
2181
2182            // Seed environment with all declared signatures first so fn bodies
2183            // can reference each other recursively (let-rec semantics).
2184            for info in &infos {
2185                self.env
2186                    .extend(info.decl.name.name.clone(), info.scheme.clone());
2187            }
2188
2189            for info in infos {
2190                let span = info.decl.span;
2191                let mut lam_body = info.decl.body.clone();
2192                let mut lam_end = lam_body.span().end;
2193                for (param, ann) in info.decl.params.iter().rev() {
2194                    let lam_constraints = Vec::new();
2195                    let span = Span::from_begin_end(param.span.begin, lam_end);
2196                    lam_body = Arc::new(Expr::Lam(
2197                        span,
2198                        Scope::new_sync(),
2199                        param.clone(),
2200                        Some(ann.clone()),
2201                        lam_constraints,
2202                        lam_body,
2203                    ));
2204                    lam_end = lam_body.span().end;
2205                }
2206
2207                let (typed, preds, inferred) = self.infer_typed(lam_body.as_ref())?;
2208                let s = unify(&inferred, &info.expected)?;
2209                let preds = preds.apply(&s);
2210                let inferred = inferred.apply(&s);
2211                let declared_preds = info.declared_preds.apply(&s);
2212                let expected = info.expected.apply(&s);
2213
2214                // Keep kind checks aligned with existing `inject_fn_decl` logic.
2215                let var_arities: HashMap<TypeVarId, usize> = info
2216                    .ann_vars
2217                    .values()
2218                    .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
2219                    .collect();
2220                for pred in &declared_preds {
2221                    let _ = entails(&self.classes, &[], pred)?;
2222                    let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
2223                    else {
2224                        continue;
2225                    };
2226                    let args: Vec<Type> = if expected_arities.len() == 1 {
2227                        vec![pred.typ.clone()]
2228                    } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2229                        if parts.len() != expected_arities.len() {
2230                            continue;
2231                        }
2232                        parts.clone()
2233                    } else {
2234                        continue;
2235                    };
2236
2237                    for (arg, expected_arity) in args.iter().zip(expected_arities.iter().copied()) {
2238                        let got = type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
2239                            TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
2240                            _ => None,
2241                        });
2242                        let Some(got) = got else {
2243                            continue;
2244                        };
2245                        if got != expected_arity {
2246                            return Err(with_span(
2247                                &span,
2248                                TypeError::KindMismatch {
2249                                    class: pred.class.clone(),
2250                                    expected: expected_arity,
2251                                    got,
2252                                    typ: arg.to_string(),
2253                                },
2254                            ));
2255                        }
2256                    }
2257                }
2258
2259                check_non_ground_predicates_declared(&self.classes, &declared_preds, &preds)
2260                    .map_err(|err| with_span(&span, err))?;
2261
2262                let _ = inferred;
2263                let _ = typed;
2264            }
2265
2266            Ok(())
2267        })();
2268
2269        if result.is_err() {
2270            self.env = saved_env;
2271            self.declared_values = saved_declared;
2272        }
2273        result
2274    }
2275
2276    pub fn inject_declare_fn_decl(&mut self, decl: &DeclareFnDecl) -> Result<(), TypeError> {
2277        let span = decl.span;
2278        (|| {
2279            // Build the declared signature type.
2280            let mut sig = decl.ret.clone();
2281            for (_, ann) in decl.params.iter().rev() {
2282                let span = Span::from_begin_end(ann.span().begin, sig.span().end);
2283                sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
2284            }
2285
2286            let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
2287            let expected =
2288                type_from_annotation_expr_vars(&self.adts, &sig, &mut ann_vars, &mut self.supply)?;
2289            let declared_preds = predicates_from_constraints(
2290                &self.adts,
2291                &decl.constraints,
2292                &mut ann_vars,
2293                &mut self.supply,
2294            )?;
2295
2296            let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
2297            vars.sort_by_key(|v| v.id);
2298            let scheme = Scheme::new(vars, declared_preds, expected);
2299            reject_ambiguous_scheme(&scheme)?;
2300
2301            // Validate referenced classes exist (and are spelled correctly).
2302            for pred in &scheme.preds {
2303                let _ = entails(&self.classes, &[], pred)?;
2304            }
2305
2306            let name = &decl.name.name;
2307
2308            // If there is already a real definition (prelude/host/`fn`), treat
2309            // `declare fn` as documentation only and ignore it.
2310            if self.env.lookup(name).is_some() && !self.declared_values.contains(name) {
2311                return Ok(());
2312            }
2313
2314            if let Some(existing) = self.env.lookup(name) {
2315                if existing.iter().any(|s| scheme_compatible(s, &scheme)) {
2316                    return Ok(());
2317                }
2318                return Err(TypeError::DuplicateValue(decl.name.name.clone()));
2319            }
2320
2321            self.env.extend(decl.name.name.clone(), scheme);
2322            self.declared_values.insert(decl.name.name.clone());
2323            Ok(())
2324        })()
2325        .map_err(|err| with_span(&span, err))
2326    }
2327
2328    pub fn instantiate_class_method_for_head(
2329        &mut self,
2330        class: &Symbol,
2331        method: &Symbol,
2332        head: &Type,
2333    ) -> Result<Type, TypeError> {
2334        let info = self
2335            .class_info
2336            .get(class)
2337            .ok_or_else(|| TypeError::UnknownClass(class.clone()))?;
2338        let scheme = info
2339            .methods
2340            .get(method)
2341            .ok_or_else(|| TypeError::UnknownInstanceMethod {
2342                class: class.clone(),
2343                method: method.clone(),
2344            })?;
2345
2346        let (preds, typ) = instantiate(scheme, &mut self.supply);
2347        let class_pred =
2348            preds
2349                .iter()
2350                .find(|p| &p.class == class)
2351                .ok_or(TypeError::UnsupportedExpr(
2352                    "class method scheme missing class predicate",
2353                ))?;
2354        let s = unify(&class_pred.typ, head)?;
2355        Ok(typ.apply(&s))
2356    }
2357
2358    pub fn typecheck_instance_method(
2359        &mut self,
2360        prepared: &PreparedInstanceDecl,
2361        method: &InstanceMethodImpl,
2362    ) -> Result<TypedExpr, TypeError> {
2363        let expected =
2364            self.instantiate_class_method_for_head(&prepared.class, &method.name, &prepared.head)?;
2365        let (typed, preds, actual) = self.infer_typed(method.body.as_ref())?;
2366        let s = unify(&actual, &expected)?;
2367        let typed = typed.apply(&s);
2368        let preds = preds.apply(&s);
2369
2370        // The only legal “given” constraints inside an instance method are the
2371        // instance context (plus superclass closure, plus the instance head
2372        // itself). We do *not* allow instance
2373        // search for non-ground constraints here, because that would be unsound:
2374        // a type variable would unify with any concrete instance head.
2375        let mut given = prepared.context.clone();
2376
2377        // Allow recursive instance methods (e.g. `Eq (List a)` calling `(==)`
2378        // on the tail). This is dictionary recursion, not instance search.
2379        given.push(Predicate::new(
2380            prepared.class.clone(),
2381            prepared.head.clone(),
2382        ));
2383        let mut i = 0;
2384        while i < given.len() {
2385            let p = given[i].clone();
2386            for sup in self.classes.supers_of(&p.class) {
2387                given.push(Predicate::new(sup, p.typ.clone()));
2388            }
2389            i += 1;
2390        }
2391
2392        for pred in &preds {
2393            if pred.typ.ftv().is_empty() {
2394                if !entails(&self.classes, &given, pred)? {
2395                    return Err(TypeError::NoInstance(
2396                        pred.class.clone(),
2397                        pred.typ.to_string(),
2398                    ));
2399                }
2400            } else if !given
2401                .iter()
2402                .any(|p| p.class == pred.class && p.typ == pred.typ)
2403            {
2404                return Err(TypeError::MissingInstanceConstraint {
2405                    method: method.name.clone(),
2406                    class: pred.class.clone(),
2407                    typ: pred.typ.to_string(),
2408                });
2409            }
2410        }
2411
2412        Ok(typed)
2413    }
2414
2415    /// Register constructor schemes for an ADT in the type environment.
2416    /// This makes constructors (e.g. `Some`, `None`, `Ok`, `Err`) available
2417    /// to the type checker as normal values.
2418    pub fn inject_adt(&mut self, adt: &AdtDecl) {
2419        self.adts.insert(adt.name.clone(), adt.clone());
2420        for (name, scheme) in adt.constructor_schemes() {
2421            self.register_value_scheme(&name, scheme);
2422        }
2423    }
2424
2425    pub fn adt_from_decl(&mut self, decl: &TypeDecl) -> Result<AdtDecl, TypeError> {
2426        let mut adt = AdtDecl::new(&decl.name, &decl.params, &mut self.supply);
2427        let mut param_map: HashMap<Symbol, TypeVar> = HashMap::new();
2428        for param in &adt.params {
2429            param_map.insert(param.name.clone(), param.var.clone());
2430        }
2431
2432        for variant in &decl.variants {
2433            let mut args = Vec::new();
2434            for arg in &variant.args {
2435                let ty = self.type_from_expr(decl, &param_map, arg)?;
2436                args.push(ty);
2437            }
2438            adt.add_variant(variant.name.clone(), args);
2439        }
2440        Ok(adt)
2441    }
2442
2443    pub fn inject_type_decl(&mut self, decl: &TypeDecl) -> Result<(), TypeError> {
2444        if BuiltinTypeId::from_symbol(&decl.name).is_some() {
2445            return Err(TypeError::ReservedTypeName(decl.name.clone()));
2446        }
2447        let adt = self.adt_from_decl(decl)?;
2448        self.inject_adt(&adt);
2449        Ok(())
2450    }
2451
2452    fn type_from_expr(
2453        &mut self,
2454        decl: &TypeDecl,
2455        params: &HashMap<Symbol, TypeVar>,
2456        expr: &TypeExpr,
2457    ) -> Result<Type, TypeError> {
2458        let span = *expr.span();
2459        let res = (|| match expr {
2460            TypeExpr::Name(_, name) => {
2461                let name_sym = name.to_dotted_symbol();
2462                if let Some(tv) = params.get(&name_sym) {
2463                    Ok(Type::var(tv.clone()))
2464                } else {
2465                    let name = normalize_type_name(&name_sym);
2466                    if let Some(arity) = self.type_arity(decl, &name) {
2467                        Ok(Type::con(name, arity))
2468                    } else {
2469                        Err(TypeError::UnknownTypeName(name))
2470                    }
2471                }
2472            }
2473            TypeExpr::App(_, fun, arg) => {
2474                let fty = self.type_from_expr(decl, params, fun)?;
2475                let aty = self.type_from_expr(decl, params, arg)?;
2476                Ok(type_app_with_result_syntax(fty, aty))
2477            }
2478            TypeExpr::Fun(_, arg, ret) => {
2479                let arg_ty = self.type_from_expr(decl, params, arg)?;
2480                let ret_ty = self.type_from_expr(decl, params, ret)?;
2481                Ok(Type::fun(arg_ty, ret_ty))
2482            }
2483            TypeExpr::Tuple(_, elems) => {
2484                let mut out = Vec::new();
2485                for elem in elems {
2486                    out.push(self.type_from_expr(decl, params, elem)?);
2487                }
2488                Ok(Type::tuple(out))
2489            }
2490            TypeExpr::Record(_, fields) => {
2491                let mut out = Vec::new();
2492                for (name, ty) in fields {
2493                    out.push((name.clone(), self.type_from_expr(decl, params, ty)?));
2494                }
2495                Ok(Type::record(out))
2496            }
2497        })();
2498        res.map_err(|err| with_span(&span, err))
2499    }
2500
2501    fn type_arity(&self, decl: &TypeDecl, name: &Symbol) -> Option<usize> {
2502        if &decl.name == name {
2503            return Some(decl.params.len());
2504        }
2505        if let Some(adt) = self.adts.get(name) {
2506            return Some(adt.params.len());
2507        }
2508        BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2509    }
2510
2511    fn register_value_scheme(&mut self, name: &Symbol, scheme: Scheme) {
2512        match self.env.lookup(name) {
2513            None => self.env.extend(name.clone(), scheme),
2514            Some(existing) => {
2515                if existing.iter().any(|s| unify(&s.typ, &scheme.typ).is_ok()) {
2516                    return;
2517                }
2518                self.env.extend_overload(name.clone(), scheme);
2519            }
2520        }
2521    }
2522
2523    pub fn infer_typed(
2524        &mut self,
2525        expr: &Expr,
2526    ) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
2527        self.infer_typed_inner(expr)
2528    }
2529
2530    pub fn infer_typed_with_gas(
2531        &mut self,
2532        expr: &Expr,
2533        gas: &mut GasMeter,
2534    ) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
2535        let known = KnownVariants::new();
2536        let mut unifier = Unifier::with_gas(gas, self.limits.max_infer_depth);
2537        let (preds, t, typed) = infer_expr(
2538            &mut unifier,
2539            &mut self.supply,
2540            &self.env,
2541            &self.adts,
2542            &known,
2543            expr,
2544        )
2545        .map_err(|err| with_span(expr.span(), err))?;
2546        let subst = unifier.into_subst();
2547        let mut typed = typed.apply(&subst);
2548        let mut preds = dedup_preds(preds.apply(&subst));
2549        let mut t = t.apply(&subst);
2550        let improve = improve_indexable(&preds)?;
2551        if !subst_is_empty(&improve) {
2552            typed = typed.apply(&improve);
2553            preds = dedup_preds(preds.apply(&improve));
2554            t = t.apply(&improve);
2555        }
2556        self.check_predicate_kinds(&preds)?;
2557        Ok((typed, preds, t))
2558    }
2559
2560    fn infer_typed_inner(
2561        &mut self,
2562        expr: &Expr,
2563    ) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
2564        let known = KnownVariants::new();
2565        let mut unifier = Unifier::new(self.limits.max_infer_depth);
2566        let (preds, t, typed) = infer_expr(
2567            &mut unifier,
2568            &mut self.supply,
2569            &self.env,
2570            &self.adts,
2571            &known,
2572            expr,
2573        )
2574        .map_err(|err| with_span(expr.span(), err))?;
2575        let subst = unifier.into_subst();
2576        let mut typed = typed.apply(&subst);
2577        let mut preds = dedup_preds(preds.apply(&subst));
2578        let mut t = t.apply(&subst);
2579        let improve = improve_indexable(&preds)?;
2580        if !subst_is_empty(&improve) {
2581            typed = typed.apply(&improve);
2582            preds = dedup_preds(preds.apply(&improve));
2583            t = t.apply(&improve);
2584        }
2585        self.check_predicate_kinds(&preds)?;
2586        Ok((typed, preds, t))
2587    }
2588
2589    pub fn infer(&mut self, expr: &Expr) -> Result<(Vec<Predicate>, Type), TypeError> {
2590        self.infer_inner(expr)
2591    }
2592
2593    pub fn infer_with_gas(
2594        &mut self,
2595        expr: &Expr,
2596        gas: &mut GasMeter,
2597    ) -> Result<(Vec<Predicate>, Type), TypeError> {
2598        let known = KnownVariants::new();
2599        let mut unifier = Unifier::with_gas(gas, self.limits.max_infer_depth);
2600        let (preds, t) = infer_expr_type(
2601            &mut unifier,
2602            &mut self.supply,
2603            &self.env,
2604            &self.adts,
2605            &known,
2606            expr,
2607        )
2608        .map_err(|err| with_span(expr.span(), err))?;
2609        let subst = unifier.into_subst();
2610        let preds = dedup_preds(preds.apply(&subst));
2611        let t = t.apply(&subst);
2612        self.check_predicate_kinds(&preds)?;
2613        finalize_infer_for_public_api(preds, t)
2614    }
2615
2616    fn infer_inner(&mut self, expr: &Expr) -> Result<(Vec<Predicate>, Type), TypeError> {
2617        let known = KnownVariants::new();
2618        let mut unifier = Unifier::new(self.limits.max_infer_depth);
2619        let (preds, t) = infer_expr_type(
2620            &mut unifier,
2621            &mut self.supply,
2622            &self.env,
2623            &self.adts,
2624            &known,
2625            expr,
2626        )
2627        .map_err(|err| with_span(expr.span(), err))?;
2628        let subst = unifier.into_subst();
2629        let mut preds = dedup_preds(preds.apply(&subst));
2630        let mut t = t.apply(&subst);
2631        let improve = improve_indexable(&preds)?;
2632        if !subst_is_empty(&improve) {
2633            preds = dedup_preds(preds.apply(&improve));
2634            t = t.apply(&improve);
2635        }
2636        self.check_predicate_kinds(&preds)?;
2637        finalize_infer_for_public_api(preds, t)
2638    }
2639
2640    fn expected_class_param_arities(&self, class: &Symbol) -> Option<Vec<usize>> {
2641        let info = self.class_info.get(class)?;
2642        let mut out = vec![0usize; info.params.len()];
2643        for scheme in info.methods.values() {
2644            for (idx, param) in info.params.iter().enumerate() {
2645                let Some(tv) = scheme.vars.iter().find(|v| v.name.as_ref() == Some(param)) else {
2646                    continue;
2647                };
2648                out[idx] = out[idx].max(max_head_app_arity_for_var(&scheme.typ, tv.id));
2649            }
2650        }
2651        Some(out)
2652    }
2653
2654    fn check_predicate_kind(&self, pred: &Predicate) -> Result<(), TypeError> {
2655        let Some(expected) = self.expected_class_param_arities(&pred.class) else {
2656            // Host-injected classes (via Rust API) won't have `class_info`.
2657            return Ok(());
2658        };
2659
2660        let args: Vec<Type> = if expected.len() == 1 {
2661            vec![pred.typ.clone()]
2662        } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2663            if parts.len() != expected.len() {
2664                return Ok(());
2665            }
2666            parts.clone()
2667        } else {
2668            return Ok(());
2669        };
2670
2671        for (arg, expected_arity) in args.iter().zip(expected.iter().copied()) {
2672            let Some(got) = type_term_remaining_arity(arg) else {
2673                // If we can't determine the arity (e.g. a bare type var), skip:
2674                // call sites may fix it up, and Rex does not currently do full
2675                // kind inference.
2676                continue;
2677            };
2678            if got != expected_arity {
2679                return Err(TypeError::KindMismatch {
2680                    class: pred.class.clone(),
2681                    expected: expected_arity,
2682                    got,
2683                    typ: arg.to_string(),
2684                });
2685            }
2686        }
2687        Ok(())
2688    }
2689
2690    fn check_predicate_kinds(&self, preds: &[Predicate]) -> Result<(), TypeError> {
2691        for pred in preds {
2692            self.check_predicate_kind(pred)?;
2693        }
2694        Ok(())
2695    }
2696}
2697
/// Computes a substitution that improves types using `Indexable` predicates.
///
/// Only predicates of class `Indexable` whose argument is a two-element tuple
/// `(container, elem)` are considered; all others are ignored. For each one,
/// [`indexable_elem_subst`] yields a substitution relating `elem` to the
/// container's element type, which is folded into the accumulated result.
///
/// The sweep repeats until a full pass adds nothing, so an improvement found
/// via one predicate can enable improvements in another.
///
/// # Errors
/// Propagates any unification error from [`indexable_elem_subst`].
fn improve_indexable(preds: &[Predicate]) -> Result<Subst, TypeError> {
    let mut subst = Subst::new_sync();
    // Fixed-point loop: stop once a whole sweep produces only empty substitutions.
    // NOTE(review): termination relies on `indexable_elem_subst` returning an
    // empty substitution once its inputs are already unified; confirm that
    // `unify` returns an empty subst for identical types.
    loop {
        let mut changed = false;
        for pred in preds {
            // Re-apply what has been learned so far so each predicate sees the
            // current bindings of its type variables.
            let pred = pred.apply(&subst);
            if pred.class.as_ref() != "Indexable" {
                continue;
            }
            let TypeKind::Tuple(parts) = pred.typ.as_ref() else {
                continue;
            };
            if parts.len() != 2 {
                continue;
            }
            let container = parts[0].clone();
            let elem = parts[1].clone();
            let s = indexable_elem_subst(&container, &elem)?;
            if !subst_is_empty(&s) {
                // Fold the new bindings into the accumulated substitution.
                subst = compose_subst(s, subst);
                changed = true;
            }
        }
        if !changed {
            break;
        }
    }
    Ok(subst)
}
2727
/// Returns a substitution relating `elem` to the element type of `container`.
///
/// - `List a` / `Array a` (a builtin `List`/`Array` constructor applied to one
///   argument): unifies `elem` with `a`.
/// - Non-empty tuples: unifies all components with one another, then unifies
///   `elem` with the resulting common component type.
/// - Empty tuples and any other container shape: returns an empty
///   substitution (no improvement).
///
/// # Errors
/// Propagates unification failures. For example, a tuple container whose
/// components do not unify produces an error.
fn indexable_elem_subst(container: &Type, elem: &Type) -> Result<Subst, TypeError> {
    match container.as_ref() {
        TypeKind::App(head, arg) => match head.as_ref() {
            TypeKind::Con(tc)
                if matches!(
                    tc.builtin_id,
                    Some(BuiltinTypeId::List | BuiltinTypeId::Array)
                ) =>
            {
                unify(elem, arg)
            }
            _ => Ok(Subst::new_sync()),
        },
        TypeKind::Tuple(elems) => {
            if elems.is_empty() {
                return Ok(Subst::new_sync());
            }
            // Unify every component against a running "current" type,
            // accumulating the bindings as we go.
            let mut subst = Subst::new_sync();
            let mut cur = elems[0].clone();
            for ty in elems.iter().skip(1) {
                let s_next = unify(&cur.apply(&subst), &ty.apply(&subst))?;
                subst = compose_subst(s_next, subst);
                cur = cur.apply(&subst);
            }
            // Finally tie the requested element type to the common component type.
            let elem = elem.apply(&subst);
            let s_elem = unify(&elem, &cur.apply(&subst))?;
            Ok(compose_subst(s_elem, subst))
        }
        _ => Ok(Subst::new_sync()),
    }
}
2759
2760fn type_from_annotation_expr(
2761    adts: &HashMap<Symbol, AdtDecl>,
2762    expr: &TypeExpr,
2763) -> Result<Type, TypeError> {
2764    let span = *expr.span();
2765    let res = (|| match expr {
2766        TypeExpr::Name(_, name) => {
2767            let name = normalize_type_name(&name.to_dotted_symbol());
2768            match annotation_type_arity(adts, &name) {
2769                Some(arity) => Ok(Type::con(name, arity)),
2770                None => Err(TypeError::UnknownTypeName(name)),
2771            }
2772        }
2773        TypeExpr::App(_, fun, arg) => {
2774            let fty = type_from_annotation_expr(adts, fun)?;
2775            let aty = type_from_annotation_expr(adts, arg)?;
2776            Ok(type_app_with_result_syntax(fty, aty))
2777        }
2778        TypeExpr::Fun(_, arg, ret) => {
2779            let arg_ty = type_from_annotation_expr(adts, arg)?;
2780            let ret_ty = type_from_annotation_expr(adts, ret)?;
2781            Ok(Type::fun(arg_ty, ret_ty))
2782        }
2783        TypeExpr::Tuple(_, elems) => {
2784            let mut out = Vec::new();
2785            for elem in elems {
2786                out.push(type_from_annotation_expr(adts, elem)?);
2787            }
2788            Ok(Type::tuple(out))
2789        }
2790        TypeExpr::Record(_, fields) => {
2791            let mut out = Vec::new();
2792            for (name, ty) in fields {
2793                out.push((name.clone(), type_from_annotation_expr(adts, ty)?));
2794            }
2795            Ok(Type::record(out))
2796        }
2797    })();
2798    res.map_err(|err| with_span(&span, err))
2799}
2800
2801fn type_from_annotation_expr_vars(
2802    adts: &HashMap<Symbol, AdtDecl>,
2803    expr: &TypeExpr,
2804    vars: &mut HashMap<Symbol, TypeVar>,
2805    supply: &mut TypeVarSupply,
2806) -> Result<Type, TypeError> {
2807    let span = *expr.span();
2808    let res = (|| match expr {
2809        TypeExpr::Name(_, name) => {
2810            let name = normalize_type_name(&name.to_dotted_symbol());
2811            if let Some(arity) = annotation_type_arity(adts, &name) {
2812                Ok(Type::con(name, arity))
2813            } else if let Some(tv) = vars.get(&name) {
2814                Ok(Type::var(tv.clone()))
2815            } else {
2816                let is_upper = name
2817                    .chars()
2818                    .next()
2819                    .map(|c| c.is_uppercase())
2820                    .unwrap_or(false);
2821                if is_upper {
2822                    return Err(TypeError::UnknownTypeName(name));
2823                }
2824                let tv = supply.fresh(Some(name.clone()));
2825                vars.insert(name.clone(), tv.clone());
2826                Ok(Type::var(tv))
2827            }
2828        }
2829        TypeExpr::App(_, fun, arg) => {
2830            let fty = type_from_annotation_expr_vars(adts, fun, vars, supply)?;
2831            let aty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2832            Ok(type_app_with_result_syntax(fty, aty))
2833        }
2834        TypeExpr::Fun(_, arg, ret) => {
2835            let arg_ty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2836            let ret_ty = type_from_annotation_expr_vars(adts, ret, vars, supply)?;
2837            Ok(Type::fun(arg_ty, ret_ty))
2838        }
2839        TypeExpr::Tuple(_, elems) => {
2840            let mut out = Vec::new();
2841            for elem in elems {
2842                out.push(type_from_annotation_expr_vars(adts, elem, vars, supply)?);
2843            }
2844            Ok(Type::tuple(out))
2845        }
2846        TypeExpr::Record(_, fields) => {
2847            let mut out = Vec::new();
2848            for (name, ty) in fields {
2849                out.push((
2850                    name.clone(),
2851                    type_from_annotation_expr_vars(adts, ty, vars, supply)?,
2852                ));
2853            }
2854            Ok(Type::record(out))
2855        }
2856    })();
2857    res.map_err(|err| with_span(&span, err))
2858}
2859
2860fn annotation_type_arity(adts: &HashMap<Symbol, AdtDecl>, name: &Symbol) -> Option<usize> {
2861    if let Some(adt) = adts.get(name) {
2862        return Some(adt.params.len());
2863    }
2864    BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2865}
2866
2867fn normalize_type_name(name: &Symbol) -> Symbol {
2868    if name.as_ref() == "str" {
2869        BuiltinTypeId::String.as_symbol()
2870    } else {
2871        name.clone()
2872    }
2873}
2874
2875fn type_app_with_result_syntax(fun: Type, arg: Type) -> Type {
2876    // Support Rust-style `Result ok err` syntax while keeping the internal
2877    // representation as `Result err ok` (so `Result err` remains the 1-argument
2878    // type constructor used for HKTs).
2879    if let TypeKind::App(head, ok) = fun.as_ref()
2880        && matches!(
2881            head.as_ref(),
2882            TypeKind::Con(c)
2883                if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
2884        )
2885    {
2886        return Type::app(Type::app(head.clone(), arg), ok.clone());
2887    }
2888    Type::app(fun, arg)
2889}
2890
/// Output of [`collect_lambda_chain`], as a tuple of:
/// - the chain's parameters in order, each a name plus its optional annotation;
/// - the type constraints attached to the chain;
/// - the first expression not consumed as part of the chain (its body).
type LambdaChain<'a> = (
    Vec<(Symbol, Option<TypeExpr>)>,
    Vec<TypeConstraint>,
    &'a Expr,
);
2896
2897fn collect_lambda_chain<'a>(expr: &'a Expr) -> LambdaChain<'a> {
2898    let mut params = Vec::new();
2899    let mut constraints = Vec::new();
2900    let mut cur = expr;
2901    let mut seen_constraints = false;
2902    while let Expr::Lam(_, _scope, param, ann, lam_constraints, body) = cur {
2903        if !lam_constraints.is_empty() {
2904            if seen_constraints {
2905                break;
2906            }
2907            constraints = lam_constraints.clone();
2908            seen_constraints = true;
2909        }
2910        params.push((param.name.clone(), ann.clone()));
2911        cur = body.as_ref();
2912    }
2913    (params, constraints, cur)
2914}
2915
2916fn predicates_from_constraints(
2917    adts: &HashMap<Symbol, AdtDecl>,
2918    constraints: &[TypeConstraint],
2919    vars: &mut HashMap<Symbol, TypeVar>,
2920    supply: &mut TypeVarSupply,
2921) -> Result<Vec<Predicate>, TypeError> {
2922    let mut out = Vec::with_capacity(constraints.len());
2923    for constraint in constraints {
2924        let ty = type_from_annotation_expr_vars(adts, &constraint.typ, vars, supply)?;
2925        out.push(Predicate::new(constraint.class.as_ref(), ty));
2926    }
2927    Ok(out)
2928}
2929
2930fn collect_app_chain(expr: &Expr) -> (&Expr, Vec<&Expr>) {
2931    let mut args = Vec::new();
2932    let mut cur = expr;
2933    while let Expr::App(_, f, x) = cur {
2934        args.push(x.as_ref());
2935        cur = f.as_ref();
2936    }
2937    args.reverse();
2938    (cur, args)
2939}
2940
2941fn narrow_overload_candidates(candidates: &[Type], arg_ty: &Type) -> Vec<Type> {
2942    let mut out = Vec::new();
2943    for candidate in candidates {
2944        let Some((params, ret)) = decompose_fun(candidate, 1) else {
2945            continue;
2946        };
2947        let param = &params[0];
2948        if let Ok(s) = unify(param, arg_ty) {
2949            out.push(ret.apply(&s));
2950        }
2951    }
2952    out
2953}
2954
2955fn unary_app_arg(typ: &Type, ctor_name: &str) -> Option<Type> {
2956    let TypeKind::App(head, arg) = typ.as_ref() else {
2957        return None;
2958    };
2959    let TypeKind::Con(tc) = head.as_ref() else {
2960        return None;
2961    };
2962    (tc.name.as_ref() == ctor_name && tc.arity == 1).then(|| arg.clone())
2963}
2964
2965fn infer_app_arg_type(
2966    unifier: &mut Unifier<'_>,
2967    supply: &mut TypeVarSupply,
2968    env: &TypeEnv,
2969    adts: &HashMap<Symbol, AdtDecl>,
2970    known: &KnownVariants,
2971    arg_hint: Option<Type>,
2972    arg: &Expr,
2973) -> Result<(Vec<Predicate>, Type), TypeError> {
2974    match (arg_hint, arg) {
2975        (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
2976            infer_record_update_type_with_hint(
2977                unifier,
2978                supply,
2979                env,
2980                adts,
2981                known,
2982                base.as_ref(),
2983                updates,
2984                &arg_hint,
2985            )
2986        }
2987        (Some(arg_hint), Expr::Dict(_, kvs))
2988            if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
2989        {
2990            let TypeKind::Record(fields) = arg_hint.as_ref() else {
2991                unreachable!("guarded by matches!")
2992            };
2993            let expected: HashMap<_, _> =
2994                fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
2995            let mut seen = HashSet::new();
2996            let mut preds = Vec::new();
2997            for (k, v) in kvs {
2998                let expected_ty = expected
2999                    .get(k)
3000                    .ok_or_else(|| TypeError::UnknownField {
3001                        field: k.clone(),
3002                        typ: Type::record(fields.clone()).to_string(),
3003                    })?
3004                    .clone();
3005                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
3006                unifier.unify(&t1, &expected_ty)?;
3007                preds.extend(p1);
3008                seen.insert(k.clone());
3009            }
3010            for key in expected.keys() {
3011                if !seen.contains(key.as_ref()) {
3012                    return Err(TypeError::UnknownField {
3013                        field: key.clone(),
3014                        typ: Type::record(fields.clone()).to_string(),
3015                    });
3016                }
3017            }
3018            let record_ty = Type::record(
3019                fields
3020                    .iter()
3021                    .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
3022                    .collect(),
3023            );
3024            Ok((preds, record_ty))
3025        }
3026        _ => infer_expr_type(unifier, supply, env, adts, known, arg),
3027    }
3028}
3029
3030fn infer_app_arg_typed(
3031    unifier: &mut Unifier<'_>,
3032    supply: &mut TypeVarSupply,
3033    env: &TypeEnv,
3034    adts: &HashMap<Symbol, AdtDecl>,
3035    known: &KnownVariants,
3036    arg_hint: Option<Type>,
3037    arg: &Expr,
3038) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
3039    match (arg_hint, arg) {
3040        (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
3041            infer_record_update_typed_with_hint(
3042                unifier,
3043                supply,
3044                env,
3045                adts,
3046                known,
3047                base.as_ref(),
3048                updates,
3049                &arg_hint,
3050            )
3051        }
3052        (Some(arg_hint), Expr::Dict(_, kvs))
3053            if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
3054        {
3055            let TypeKind::Record(fields) = arg_hint.as_ref() else {
3056                unreachable!("guarded by matches!")
3057            };
3058            let mut preds = Vec::new();
3059            let mut typed_kvs = BTreeMap::new();
3060            let expected: HashMap<_, _> =
3061                fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
3062            for (k, v) in kvs {
3063                let expected_ty = expected
3064                    .get(k)
3065                    .ok_or_else(|| TypeError::UnknownField {
3066                        field: k.clone(),
3067                        typ: Type::record(fields.clone()).to_string(),
3068                    })?
3069                    .clone();
3070                let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
3071                unifier.unify(&t1, &expected_ty)?;
3072                preds.extend(p1);
3073                typed_kvs.insert(k.clone(), typed_v);
3074            }
3075            for key in expected.keys() {
3076                if !typed_kvs.contains_key(key.as_ref()) {
3077                    return Err(TypeError::UnknownField {
3078                        field: key.clone(),
3079                        typ: Type::record(fields.clone()).to_string(),
3080                    });
3081                }
3082            }
3083            let record_ty = Type::record(
3084                fields
3085                    .iter()
3086                    .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
3087                    .collect(),
3088            );
3089            let typed = TypedExpr::new(record_ty.clone(), TypedExprKind::Dict(typed_kvs));
3090            Ok((preds, record_ty, typed))
3091        }
3092        _ => infer_expr(unifier, supply, env, adts, known, arg),
3093    }
3094}
3095
3096#[allow(clippy::too_many_arguments)]
3097fn infer_record_update_type_with_hint(
3098    unifier: &mut Unifier<'_>,
3099    supply: &mut TypeVarSupply,
3100    env: &TypeEnv,
3101    adts: &HashMap<Symbol, AdtDecl>,
3102    known: &KnownVariants,
3103    base: &Expr,
3104    updates: &BTreeMap<Symbol, Arc<Expr>>,
3105    hint_ty: &Type,
3106) -> Result<(Vec<Predicate>, Type), TypeError> {
3107    let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
3108    unifier.unify(&t_base, hint_ty)?;
3109    let base_ty = unifier.apply_type(&t_base);
3110    let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
3111    let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
3112    let (result_ty, fields) = resolve_record_update(
3113        unifier,
3114        supply,
3115        adts,
3116        &base_ty,
3117        known_variant,
3118        &update_fields,
3119    )?;
3120    let expected: HashMap<_, _> = fields.into_iter().collect();
3121
3122    let mut preds = p_base;
3123    for (k, v) in updates {
3124        let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
3125            field: k.clone(),
3126            typ: result_ty.to_string(),
3127        })?;
3128        let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
3129        unifier.unify(&t1, expected_ty)?;
3130        preds.extend(p1);
3131    }
3132    Ok((preds, result_ty))
3133}
3134
3135#[allow(clippy::too_many_arguments)]
3136fn infer_record_update_typed_with_hint(
3137    unifier: &mut Unifier<'_>,
3138    supply: &mut TypeVarSupply,
3139    env: &TypeEnv,
3140    adts: &HashMap<Symbol, AdtDecl>,
3141    known: &KnownVariants,
3142    base: &Expr,
3143    updates: &BTreeMap<Symbol, Arc<Expr>>,
3144    hint_ty: &Type,
3145) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
3146    let (p_base, t_base, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
3147    unifier.unify(&t_base, hint_ty)?;
3148    let base_ty = unifier.apply_type(&t_base);
3149    let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
3150    let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
3151    let (result_ty, fields) = resolve_record_update(
3152        unifier,
3153        supply,
3154        adts,
3155        &base_ty,
3156        known_variant,
3157        &update_fields,
3158    )?;
3159    let expected: HashMap<_, _> = fields.into_iter().collect();
3160
3161    let mut preds = p_base;
3162    let mut typed_updates = BTreeMap::new();
3163    for (k, v) in updates {
3164        let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
3165            field: k.clone(),
3166            typ: result_ty.to_string(),
3167        })?;
3168        let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
3169        unifier.unify(&t1, expected_ty)?;
3170        preds.extend(p1);
3171        typed_updates.insert(k.clone(), typed_v);
3172    }
3173
3174    let typed = TypedExpr::new(
3175        result_ty.clone(),
3176        TypedExprKind::RecordUpdate {
3177            base: Box::new(typed_base),
3178            updates: typed_updates,
3179        },
3180    );
3181    Ok((preds, result_ty, typed))
3182}
3183
3184fn infer_expr_type(
3185    unifier: &mut Unifier<'_>,
3186    supply: &mut TypeVarSupply,
3187    env: &TypeEnv,
3188    adts: &HashMap<Symbol, AdtDecl>,
3189    known: &KnownVariants,
3190    expr: &Expr,
3191) -> Result<(Vec<Predicate>, Type), TypeError> {
3192    let span = *expr.span();
3193    let res = unifier.with_infer_depth(span, |unifier| {
3194        infer_expr_type_inner(unifier, supply, env, adts, known, expr)
3195    });
3196    res.map_err(|err| with_span(&span, err))
3197}
3198
3199fn infer_expr_type_inner(
3200    unifier: &mut Unifier<'_>,
3201    supply: &mut TypeVarSupply,
3202    env: &TypeEnv,
3203    adts: &HashMap<Symbol, AdtDecl>,
3204    known: &KnownVariants,
3205    expr: &Expr,
3206) -> Result<(Vec<Predicate>, Type), TypeError> {
3207    unifier.charge_infer_node()?;
3208    match expr {
3209        Expr::Bool(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Bool))),
3210        Expr::Uint(_, _) => {
3211            let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
3212            Ok((vec![Predicate::new("Integral", lit_ty.clone())], lit_ty))
3213        }
3214        Expr::Int(_, _) => {
3215            let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
3216            Ok((
3217                vec![
3218                    Predicate::new("Integral", lit_ty.clone()),
3219                    Predicate::new("AdditiveGroup", lit_ty.clone()),
3220                ],
3221                lit_ty,
3222            ))
3223        }
3224        Expr::Float(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::F32))),
3225        Expr::String(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::String))),
3226        Expr::Uuid(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Uuid))),
3227        Expr::DateTime(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::DateTime))),
3228        Expr::Hole(_) => {
3229            let t = Type::var(supply.fresh(Some(sym("hole"))));
3230            Ok((vec![], t))
3231        }
3232        Expr::Var(var) => {
3233            let schemes = env
3234                .lookup(&var.name)
3235                .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
3236            if schemes.len() == 1 {
3237                let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
3238                let (preds, t) = instantiate(&scheme, supply);
3239                Ok((preds, t))
3240            } else {
3241                for scheme in schemes {
3242                    if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
3243                        return Err(TypeError::AmbiguousOverload(var.name.clone()));
3244                    }
3245                }
3246                let t = Type::var(supply.fresh(Some(var.name.clone())));
3247                Ok((vec![], t))
3248            }
3249        }
3250        Expr::Lam(..) => {
3251            let (params, constraints, body) = collect_lambda_chain(expr);
3252            let mut ann_vars = HashMap::new();
3253            let mut param_tys = Vec::with_capacity(params.len());
3254            for (name, ann) in &params {
3255                let param_ty = match ann {
3256                    Some(ann) => type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?,
3257                    None => Type::var(supply.fresh(Some(name.clone()))),
3258                };
3259                param_tys.push((name.clone(), param_ty));
3260            }
3261
3262            let mut env1 = env.clone();
3263            let mut known_body = known.clone();
3264            for (name, param_ty) in &param_tys {
3265                env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
3266                known_body.remove(name);
3267            }
3268
3269            let (mut preds, body_ty) =
3270                infer_expr_type(unifier, supply, &env1, adts, &known_body, body)?;
3271            let constraint_preds =
3272                predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
3273            preds.extend(constraint_preds);
3274
3275            let mut fun_ty = unifier.apply_type(&body_ty);
3276            for (_, param_ty) in param_tys.iter().rev() {
3277                fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
3278            }
3279            Ok((preds, fun_ty))
3280        }
3281        Expr::App(..) => {
3282            let (head, args) = collect_app_chain(expr);
3283            let (mut preds, mut func_ty) =
3284                infer_expr_type(unifier, supply, env, adts, known, head)?;
3285            let mut overload_name = None;
3286            let mut overload_candidates = if let Expr::Var(var) = head {
3287                if let Some(schemes) = env.lookup(&var.name) {
3288                    if schemes.len() <= 1 {
3289                        None
3290                    } else {
3291                        let mut candidates = Vec::new();
3292                        for scheme in schemes {
3293                            if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
3294                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
3295                            }
3296                            let scheme = apply_scheme_with_unifier(scheme, unifier);
3297                            let (p, typ) = instantiate(&scheme, supply);
3298                            if !p.is_empty() {
3299                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
3300                            }
3301                            candidates.push(typ);
3302                        }
3303                        overload_name = Some(var.name.clone());
3304                        Some(candidates)
3305                    }
3306                } else {
3307                    None
3308                }
3309            } else {
3310                None
3311            };
3312            for arg in args {
3313                let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
3314                    TypeKind::Fun(arg, _) => Some(arg.clone()),
3315                    _ => None,
3316                };
3317                let (p_arg, arg_ty) =
3318                    infer_app_arg_type(unifier, supply, env, adts, known, arg_hint, arg)?;
3319                let arg_ty = unifier.apply_type(&arg_ty);
3320                if let Some(candidates) = overload_candidates.take() {
3321                    let candidates = candidates
3322                        .into_iter()
3323                        .map(|t| unifier.apply_type(&t))
3324                        .collect::<Vec<_>>();
3325                    let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
3326                    if narrowed.is_empty()
3327                        && let Some(name) = &overload_name
3328                    {
3329                        return Err(TypeError::AmbiguousOverload(name.clone()));
3330                    }
3331                    overload_candidates = Some(narrowed);
3332                }
3333                let res_ty = match overload_candidates.as_ref() {
3334                    Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
3335                    _ => Type::var(supply.fresh(Some("r".into()))),
3336                };
3337                unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
3338                preds.extend(p_arg);
3339                func_ty = match overload_candidates.as_ref() {
3340                    Some(candidates) if candidates.len() == 1 => unifier.apply_type(&candidates[0]),
3341                    _ => unifier.apply_type(&res_ty),
3342                };
3343            }
3344            Ok((preds, func_ty))
3345        }
3346        Expr::Project(_, base, field) => {
3347            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, base)?;
3348            let base_ty = unifier.apply_type(&t1);
3349            let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
3350            let field_ty =
3351                resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
3352            Ok((p1, field_ty))
3353        }
3354        Expr::RecordUpdate(_, base, updates) => {
3355            let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
3356            let base_ty = unifier.apply_type(&t_base);
3357            let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
3358            let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
3359            let (result_ty, fields) = resolve_record_update(
3360                unifier,
3361                supply,
3362                adts,
3363                &base_ty,
3364                known_variant,
3365                &update_fields,
3366            )?;
3367            let expected: HashMap<_, _> = fields.into_iter().collect();
3368
3369            let mut preds = p_base;
3370            for (k, v) in updates {
3371                let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
3372                    field: k.clone(),
3373                    typ: result_ty.to_string(),
3374                })?;
3375                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
3376                unifier.unify(&t1, expected_ty)?;
3377                preds.extend(p1);
3378            }
3379            Ok((preds, result_ty))
3380        }
3381        Expr::Let(..) => {
3382            let mut bindings = Vec::new();
3383            let mut cur = expr;
3384            while let Expr::Let(_, v, ann, d, b) = cur {
3385                bindings.push((v.clone(), ann.clone(), d.clone()));
3386                cur = b.as_ref();
3387            }
3388
3389            let mut env_cur = env.clone();
3390            let mut known_cur = known.clone();
3391            for (v, ann, d) in bindings {
3392                let (p1, t1) = if let Some(ref ann_expr) = ann {
3393                    let mut ann_vars = HashMap::new();
3394                    let ann_ty =
3395                        type_from_annotation_expr_vars(adts, ann_expr, &mut ann_vars, supply)?;
3396                    match d.as_ref() {
3397                        Expr::RecordUpdate(_, base, updates) => infer_record_update_type_with_hint(
3398                            unifier,
3399                            supply,
3400                            &env_cur,
3401                            adts,
3402                            &known_cur,
3403                            base.as_ref(),
3404                            updates,
3405                            &ann_ty,
3406                        )?,
3407                        _ => {
3408                            let (p1, t1) =
3409                                infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?;
3410                            unifier.unify(&t1, &ann_ty)?;
3411                            (p1, t1)
3412                        }
3413                    }
3414                } else {
3415                    infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?
3416                };
3417                let def_ty = unifier.apply_type(&t1);
3418                let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
3419                    monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
3420                } else {
3421                    let scheme = generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
3422                    reject_ambiguous_scheme(&scheme)?;
3423                    scheme
3424                };
3425                env_cur.extend(v.name.clone(), scheme);
3426                if let Some(known_variant) =
3427                    known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
3428                {
3429                    known_cur.insert(
3430                        v.name.clone(),
3431                        KnownVariant {
3432                            adt: known_variant.adt,
3433                            variant: known_variant.variant,
3434                        },
3435                    );
3436                } else {
3437                    known_cur.remove(&v.name);
3438                }
3439            }
3440
3441            let (p_body, t_body) =
3442                infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, cur)?;
3443            Ok((p_body, t_body))
3444        }
3445        Expr::LetRec(_, bindings, body) => {
3446            let mut env_seed = env.clone();
3447            let mut known_seed = known.clone();
3448            let mut binding_tys = HashMap::new();
3449            for (var, _ann, _def) in bindings {
3450                let tv = Type::var(supply.fresh(Some(var.name.clone())));
3451                env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
3452                known_seed.remove(&var.name);
3453                binding_tys.insert(var.name.clone(), tv);
3454            }
3455
3456            let mut inferred = Vec::with_capacity(bindings.len());
3457            for (var, ann, def) in bindings {
3458                let (preds, def_ty) =
3459                    infer_expr_type(unifier, supply, &env_seed, adts, &known_seed, def)?;
3460                if let Some(ann) = ann {
3461                    let mut ann_vars = HashMap::new();
3462                    let ann_ty = type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
3463                    unifier.unify(&def_ty, &ann_ty)?;
3464                }
3465                let binding_ty = binding_tys
3466                    .get(&var.name)
3467                    .cloned()
3468                    .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
3469                unifier.unify(&binding_ty, &def_ty)?;
3470                let resolved_ty = unifier.apply_type(&binding_ty);
3471
3472                if let Some(known_variant) =
3473                    known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
3474                {
3475                    known_seed.insert(
3476                        var.name.clone(),
3477                        KnownVariant {
3478                            adt: known_variant.adt,
3479                            variant: known_variant.variant,
3480                        },
3481                    );
3482                } else {
3483                    known_seed.remove(&var.name);
3484                }
3485                inferred.push((var.name.clone(), preds, resolved_ty));
3486            }
3487
3488            let mut env_body = env.clone();
3489            for (name, preds, def_ty) in inferred {
3490                let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
3491                reject_ambiguous_scheme(&scheme)?;
3492                env_body.extend(name, scheme);
3493            }
3494
3495            let (p_body, t_body) =
3496                infer_expr_type(unifier, supply, &env_body, adts, &known_seed, body)?;
3497            Ok((p_body, t_body))
3498        }
3499        Expr::Ite(_, cond, then_expr, else_expr) => {
3500            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, cond)?;
3501            unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
3502            let (p2, t2) = infer_expr_type(unifier, supply, env, adts, known, then_expr)?;
3503            let (p3, t3) = infer_expr_type(unifier, supply, env, adts, known, else_expr)?;
3504            unifier.unify(&t2, &t3)?;
3505            let out_ty = unifier.apply_type(&t2);
3506            let mut preds = p1;
3507            preds.extend(p2);
3508            preds.extend(p3);
3509            Ok((preds, out_ty))
3510        }
3511        Expr::Tuple(_, elems) => {
3512            let mut preds = Vec::new();
3513            let mut types = Vec::new();
3514            for elem in elems {
3515                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
3516                preds.extend(p1);
3517                types.push(unifier.apply_type(&t1));
3518            }
3519            let tuple_ty = Type::tuple(types);
3520            Ok((preds, tuple_ty))
3521        }
3522        Expr::List(_, elems) => {
3523            let elem_tv = Type::var(supply.fresh(Some("a".into())));
3524            let mut preds = Vec::new();
3525            for elem in elems {
3526                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
3527                unifier.unify(&t1, &elem_tv)?;
3528                preds.extend(p1);
3529            }
3530            let list_ty = Type::app(
3531                Type::builtin(BuiltinTypeId::List),
3532                unifier.apply_type(&elem_tv),
3533            );
3534            Ok((preds, list_ty))
3535        }
3536        Expr::Dict(_, kvs) => {
3537            let elem_tv = Type::var(supply.fresh(Some("v".into())));
3538            let mut preds = Vec::new();
3539            for v in kvs.values() {
3540                let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
3541                unifier.unify(&t1, &elem_tv)?;
3542                preds.extend(p1);
3543            }
3544            let dict_ty = Type::app(
3545                Type::builtin(BuiltinTypeId::Dict),
3546                unifier.apply_type(&elem_tv),
3547            );
3548            Ok((preds, dict_ty))
3549        }
3550        Expr::Match(_, scrutinee, arms) => {
3551            let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, scrutinee.as_ref())?;
3552            let mut preds = p1;
3553            let res_ty = Type::var(supply.fresh(Some("match".into())));
3554            let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
3555
3556            for (pat, expr) in arms {
3557                let scrutinee_ty = unifier.apply_type(&t1);
3558                let (p_pat, binds) = infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
3559                preds.extend(p_pat);
3560
3561                let mut env_arm = env.clone();
3562                for (name, ty) in binds {
3563                    env_arm.extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
3564                }
3565                let mut known_arm = known.clone();
3566                if let Expr::Var(var) = scrutinee.as_ref() {
3567                    match pat {
3568                        Pattern::Named(_, name, _) => {
3569                            let name_sym = name.to_dotted_symbol();
3570                            if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
3571                                known_arm.insert(
3572                                    var.name.clone(),
3573                                    KnownVariant {
3574                                        adt: adt.name.clone(),
3575                                        variant: name_sym,
3576                                    },
3577                                );
3578                            } else {
3579                                known_arm.remove(&var.name);
3580                            }
3581                        }
3582                        _ => {
3583                            known_arm.remove(&var.name);
3584                        }
3585                    }
3586                }
3587                let (p_expr, t_expr) =
3588                    infer_expr_type(unifier, supply, &env_arm, adts, &known_arm, expr)?;
3589                unifier.unify(&res_ty, &t_expr)?;
3590                preds.extend(p_expr);
3591            }
3592
3593            let scrutinee_ty = unifier.apply_type(&t1);
3594            check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
3595            let out_ty = unifier.apply_type(&res_ty);
3596            Ok((preds, out_ty))
3597        }
3598        Expr::Ann(_, expr, ann) => {
3599            let ann_ty = type_from_annotation_expr(adts, ann)?;
3600            match expr.as_ref() {
3601                Expr::RecordUpdate(_, base, updates) => {
3602                    let (preds, out_ty) = infer_record_update_type_with_hint(
3603                        unifier,
3604                        supply,
3605                        env,
3606                        adts,
3607                        known,
3608                        base.as_ref(),
3609                        updates,
3610                        &ann_ty,
3611                    )?;
3612                    Ok((preds, out_ty))
3613                }
3614                _ => {
3615                    let (preds, expr_ty) =
3616                        infer_expr_type(unifier, supply, env, adts, known, expr)?;
3617                    unifier.unify(&expr_ty, &ann_ty)?;
3618                    let out_ty = unifier.apply_type(&ann_ty);
3619                    Ok((preds, out_ty))
3620                }
3621            }
3622        }
3623    }
3624}
3625
3626fn infer_expr(
3627    unifier: &mut Unifier<'_>,
3628    supply: &mut TypeVarSupply,
3629    env: &TypeEnv,
3630    adts: &HashMap<Symbol, AdtDecl>,
3631    known: &KnownVariants,
3632    expr: &Expr,
3633) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
3634    let span = *expr.span();
3635    let res = unifier.with_infer_depth(span, |unifier| {
3636        (|| {
3637            unifier.charge_infer_node()?;
3638            match expr {
3639                Expr::Bool(_, v) => {
3640                    let t = Type::builtin(BuiltinTypeId::Bool);
3641                    Ok((
3642                        vec![],
3643                        t.clone(),
3644                        TypedExpr::new(t, TypedExprKind::Bool(*v)),
3645                    ))
3646                }
3647                Expr::Uint(_, v) => {
3648                    let t = Type::var(supply.fresh(Some(sym("n"))));
3649                    Ok((
3650                        vec![Predicate::new("Integral", t.clone())],
3651                        t.clone(),
3652                        TypedExpr::new(t, TypedExprKind::Uint(*v)),
3653                    ))
3654                }
3655                Expr::Int(_, v) => {
3656                    let t = Type::var(supply.fresh(Some(sym("n"))));
3657                    Ok((
3658                        vec![
3659                            Predicate::new("Integral", t.clone()),
3660                            Predicate::new("AdditiveGroup", t.clone()),
3661                        ],
3662                        t.clone(),
3663                        TypedExpr::new(t, TypedExprKind::Int(*v)),
3664                    ))
3665                }
3666                Expr::Float(_, v) => {
3667                    let t = Type::builtin(BuiltinTypeId::F32);
3668                    Ok((
3669                        vec![],
3670                        t.clone(),
3671                        TypedExpr::new(t, TypedExprKind::Float(*v)),
3672                    ))
3673                }
3674                Expr::String(_, v) => {
3675                    let t = Type::builtin(BuiltinTypeId::String);
3676                    Ok((
3677                        vec![],
3678                        t.clone(),
3679                        TypedExpr::new(t, TypedExprKind::String(v.clone())),
3680                    ))
3681                }
3682                Expr::Uuid(_, v) => {
3683                    let t = Type::builtin(BuiltinTypeId::Uuid);
3684                    Ok((
3685                        vec![],
3686                        t.clone(),
3687                        TypedExpr::new(t, TypedExprKind::Uuid(*v)),
3688                    ))
3689                }
3690                Expr::DateTime(_, v) => {
3691                    let t = Type::builtin(BuiltinTypeId::DateTime);
3692                    Ok((
3693                        vec![],
3694                        t.clone(),
3695                        TypedExpr::new(t, TypedExprKind::DateTime(*v)),
3696                    ))
3697                }
3698                Expr::Hole(_) => {
3699                    let t = Type::var(supply.fresh(Some(sym("hole"))));
3700                    Ok((vec![], t.clone(), TypedExpr::new(t, TypedExprKind::Hole)))
3701                }
3702                Expr::Var(var) => {
3703                    let schemes = env
3704                        .lookup(&var.name)
3705                        .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
3706                    if schemes.len() == 1 {
3707                        let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
3708                        let (preds, t) = instantiate(&scheme, supply);
3709                        let typed = TypedExpr::new(
3710                            t.clone(),
3711                            TypedExprKind::Var {
3712                                name: var.name.clone(),
3713                                overloads: vec![],
3714                            },
3715                        );
3716                        Ok((preds, t, typed))
3717                    } else {
3718                        let mut overloads = Vec::new();
3719                        for scheme in schemes {
3720                            // Overloads in Rex are a *type-directed* choice at use sites.
3721                            //
3722                            // We can represent overload sets whose alternatives differ only
3723                            // by type (e.g. `prim_map` for List/Array/Option/Result). But we
3724                            // do *not* model “choice between predicate sets”: that would
3725                            // require disjunction in the constraint solver.
3726                            if !scheme.preds.is_empty() {
3727                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
3728                            }
3729
3730                            let scheme = apply_scheme_with_unifier(scheme, unifier);
3731                            let (preds, typ) = instantiate(&scheme, supply);
3732                            if !preds.is_empty() {
3733                                return Err(TypeError::AmbiguousOverload(var.name.clone()));
3734                            }
3735                            overloads.push(typ);
3736                        }
3737                        let t = Type::var(supply.fresh(Some(var.name.clone())));
3738                        let typed = TypedExpr::new(
3739                            t.clone(),
3740                            TypedExprKind::Var {
3741                                name: var.name.clone(),
3742                                overloads,
3743                            },
3744                        );
3745                        Ok((vec![], t, typed))
3746                    }
3747                }
3748                Expr::Lam(..) => {
3749                    let (params, constraints, body) = collect_lambda_chain(expr);
3750                    let mut ann_vars = HashMap::new();
3751                    let mut param_tys = Vec::with_capacity(params.len());
3752                    for (name, ann) in &params {
3753                        let param_ty = match ann {
3754                            Some(ann) => {
3755                                type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?
3756                            }
3757                            None => Type::var(supply.fresh(Some(name.clone()))),
3758                        };
3759                        param_tys.push((name.clone(), param_ty));
3760                    }
3761
3762                    let mut env1 = env.clone();
3763                    let mut known_body = known.clone();
3764                    for (name, param_ty) in &param_tys {
3765                        env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
3766                        known_body.remove(name);
3767                    }
3768
3769                    let (mut preds, body_ty, typed_body) =
3770                        infer_expr(unifier, supply, &env1, adts, &known_body, body)?;
3771                    let constraint_preds =
3772                        predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
3773                    preds.extend(constraint_preds);
3774
3775                    let mut typed = typed_body;
3776                    let mut fun_ty = unifier.apply_type(&body_ty);
3777                    for (name, param_ty) in param_tys.iter().rev() {
3778                        fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
3779                        typed = TypedExpr::new(
3780                            fun_ty.clone(),
3781                            TypedExprKind::Lam {
3782                                param: name.clone(),
3783                                body: Box::new(typed),
3784                            },
3785                        );
3786                    }
3787
3788                    Ok((preds, fun_ty, typed))
3789                }
3790                Expr::App(..) => {
3791                    let (head, args) = collect_app_chain(expr);
3792                    let (mut preds, mut func_ty, mut typed) =
3793                        infer_expr(unifier, supply, env, adts, known, head)?;
3794                    let mut overload_name = None;
3795                    let mut overload_candidates = match &typed.kind {
3796                        TypedExprKind::Var { name, overloads } if !overloads.is_empty() => {
3797                            overload_name = Some(name.clone());
3798                            Some(overloads.clone())
3799                        }
3800                        _ => None,
3801                    };
3802                    for arg in args {
3803                        let expected_arg = match unifier.apply_type(&func_ty).as_ref() {
3804                            TypeKind::Fun(arg, _) => Some(arg.clone()),
3805                            _ => None,
3806                        };
3807                        let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
3808                            TypeKind::Fun(arg, _) => Some(arg.clone()),
3809                            _ => None,
3810                        };
3811                        let (p_arg, arg_ty, typed_arg) =
3812                            infer_app_arg_typed(unifier, supply, env, adts, known, arg_hint, arg)?;
3813                        let mut arg_ty = unifier.apply_type(&arg_ty);
3814                        let mut typed_arg = typed_arg;
3815
3816                        // Narrow implicit coercion: in function argument position only,
3817                        // coerce `List a` to `Array a` when the callee expects `Array a`.
3818                        if let Some(expected_arg) = expected_arg {
3819                            let expected_arg = unifier.apply_type(&expected_arg);
3820                            if let (Some(expected_elem), Some(arg_elem)) = (
3821                                unary_app_arg(&expected_arg, "Array"),
3822                                unary_app_arg(&arg_ty, "List"),
3823                            ) {
3824                                unifier.unify(&expected_elem, &arg_elem)?;
3825                                let elem_ty = unifier.apply_type(&expected_elem);
3826                                let list_ty = Type::list(elem_ty.clone());
3827                                let array_ty = Type::array(elem_ty);
3828                                let coercion_ty = Type::fun(list_ty, array_ty.clone());
3829                                let coercion_fn = TypedExpr::new(
3830                                    coercion_ty,
3831                                    TypedExprKind::Var {
3832                                        name: sym("prim_array_from_list"),
3833                                        overloads: vec![],
3834                                    },
3835                                );
3836                                typed_arg = TypedExpr::new(
3837                                    array_ty.clone(),
3838                                    TypedExprKind::App(Box::new(coercion_fn), Box::new(typed_arg)),
3839                                );
3840                                arg_ty = array_ty;
3841                            }
3842                        }
3843                        if let Some(candidates) = overload_candidates.take() {
3844                            let candidates = candidates
3845                                .into_iter()
3846                                .map(|t| unifier.apply_type(&t))
3847                                .collect::<Vec<_>>();
3848                            let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
3849                            if narrowed.is_empty()
3850                                && let Some(name) = &overload_name
3851                            {
3852                                return Err(TypeError::AmbiguousOverload(name.clone()));
3853                            }
3854                            overload_candidates = Some(narrowed);
3855                        }
3856                        let res_ty = match overload_candidates.as_ref() {
3857                            Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
3858                            _ => Type::var(supply.fresh(Some("r".into()))),
3859                        };
3860                        unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
3861                        let result_ty = match overload_candidates.as_ref() {
3862                            Some(candidates) if candidates.len() == 1 => {
3863                                unifier.apply_type(&candidates[0])
3864                            }
3865                            _ => unifier.apply_type(&res_ty),
3866                        };
3867                        preds.extend(p_arg);
3868                        typed = TypedExpr::new(
3869                            result_ty.clone(),
3870                            TypedExprKind::App(Box::new(typed), Box::new(typed_arg)),
3871                        );
3872                        func_ty = result_ty;
3873                    }
3874                    Ok((preds, func_ty, typed))
3875                }
3876                Expr::Project(_, base, field) => {
3877                    let (p1, t1, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
3878                    let base_ty = unifier.apply_type(&t1);
3879                    let known_variant =
3880                        known_variant_from_expr_with_known(base, &base_ty, adts, known);
3881                    let field_ty =
3882                        resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
3883                    let typed = TypedExpr::new(
3884                        field_ty.clone(),
3885                        TypedExprKind::Project {
3886                            expr: Box::new(typed_base),
3887                            field: field.clone(),
3888                        },
3889                    );
3890                    Ok((p1, field_ty, typed))
3891                }
3892                Expr::RecordUpdate(_, base, updates) => {
3893                    let (p_base, t_base, typed_base) =
3894                        infer_expr(unifier, supply, env, adts, known, base)?;
3895                    let base_ty = unifier.apply_type(&t_base);
3896                    let known_variant =
3897                        known_variant_from_expr_with_known(base, &base_ty, adts, known);
3898                    let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
3899                    let (result_ty, fields) = resolve_record_update(
3900                        unifier,
3901                        supply,
3902                        adts,
3903                        &base_ty,
3904                        known_variant,
3905                        &update_fields,
3906                    )?;
3907                    let expected: HashMap<_, _> = fields.into_iter().collect();
3908
3909                    let mut preds = p_base;
3910                    let mut typed_updates = BTreeMap::new();
3911                    for (k, v) in updates {
3912                        let expected_ty =
3913                            expected.get(k).ok_or_else(|| TypeError::UnknownField {
3914                                field: k.clone(),
3915                                typ: result_ty.to_string(),
3916                            })?;
3917                        let (p1, t1, typed_v) =
3918                            infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
3919                        unifier.unify(&t1, expected_ty)?;
3920                        preds.extend(p1);
3921                        typed_updates.insert(k.clone(), typed_v);
3922                    }
3923                    let typed = TypedExpr::new(
3924                        result_ty.clone(),
3925                        TypedExprKind::RecordUpdate {
3926                            base: Box::new(typed_base),
3927                            updates: typed_updates,
3928                        },
3929                    );
3930                    Ok((preds, result_ty, typed))
3931                }
3932                Expr::Let(..) => {
3933                    let mut bindings = Vec::new();
3934                    let mut cur = expr;
3935                    while let Expr::Let(_, v, ann, d, b) = cur {
3936                        bindings.push((v.clone(), ann.clone(), d.clone()));
3937                        cur = b.as_ref();
3938                    }
3939
3940                    let mut env_cur = env.clone();
3941                    let mut known_cur = known.clone();
3942                    let mut typed_defs = Vec::new();
3943                    for (v, ann, d) in bindings {
3944                        let (p1, t1, typed_def) = if let Some(ref ann_expr) = ann {
3945                            let mut ann_vars = HashMap::new();
3946                            let ann_ty = type_from_annotation_expr_vars(
3947                                adts,
3948                                ann_expr,
3949                                &mut ann_vars,
3950                                supply,
3951                            )?;
3952                            match d.as_ref() {
3953                                Expr::RecordUpdate(_, base, updates) => {
3954                                    infer_record_update_typed_with_hint(
3955                                        unifier,
3956                                        supply,
3957                                        &env_cur,
3958                                        adts,
3959                                        &known_cur,
3960                                        base.as_ref(),
3961                                        updates,
3962                                        &ann_ty,
3963                                    )?
3964                                }
3965                                _ => {
3966                                    let (p1, t1, typed_def) = infer_expr(
3967                                        unifier, supply, &env_cur, adts, &known_cur, &d,
3968                                    )?;
3969                                    unifier.unify(&t1, &ann_ty)?;
3970                                    (p1, t1, typed_def)
3971                                }
3972                            }
3973                        } else {
3974                            infer_expr(unifier, supply, &env_cur, adts, &known_cur, &d)?
3975                        };
3976                        let def_ty = unifier.apply_type(&t1);
3977                        let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
3978                            monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
3979                        } else {
3980                            let scheme =
3981                                generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
3982                            reject_ambiguous_scheme(&scheme)?;
3983                            scheme
3984                        };
3985                        env_cur.extend(v.name.clone(), scheme);
3986                        if let Some(known_variant) =
3987                            known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
3988                        {
3989                            known_cur.insert(
3990                                v.name.clone(),
3991                                KnownVariant {
3992                                    adt: known_variant.adt,
3993                                    variant: known_variant.variant,
3994                                },
3995                            );
3996                        } else {
3997                            known_cur.remove(&v.name);
3998                        }
3999                        typed_defs.push((v.name.clone(), typed_def));
4000                    }
4001
4002                    let (p_body, t_body, typed_body) =
4003                        infer_expr(unifier, supply, &env_cur, adts, &known_cur, cur)?;
4004
4005                    let mut typed = typed_body;
4006                    for (name, def) in typed_defs.into_iter().rev() {
4007                        typed = TypedExpr::new(
4008                            t_body.clone(),
4009                            TypedExprKind::Let {
4010                                name,
4011                                def: Box::new(def),
4012                                body: Box::new(typed),
4013                            },
4014                        );
4015                    }
4016                    Ok((p_body, t_body, typed))
4017                }
4018                Expr::LetRec(_, bindings, body) => {
4019                    let mut env_seed = env.clone();
4020                    let mut known_seed = known.clone();
4021                    let mut binding_tys = HashMap::new();
4022                    for (var, _ann, _def) in bindings {
4023                        let tv = Type::var(supply.fresh(Some(var.name.clone())));
4024                        env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
4025                        known_seed.remove(&var.name);
4026                        binding_tys.insert(var.name.clone(), tv);
4027                    }
4028
4029                    let mut inferred_defs = Vec::with_capacity(bindings.len());
4030                    for (var, ann, def) in bindings {
4031                        let (preds, def_ty, typed_def) =
4032                            infer_expr(unifier, supply, &env_seed, adts, &known_seed, def)?;
4033                        if let Some(ann) = ann {
4034                            let mut ann_vars = HashMap::new();
4035                            let ann_ty =
4036                                type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
4037                            unifier.unify(&def_ty, &ann_ty)?;
4038                        }
4039                        let binding_ty = binding_tys
4040                            .get(&var.name)
4041                            .cloned()
4042                            .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
4043                        unifier.unify(&binding_ty, &def_ty)?;
4044                        let resolved_ty = unifier.apply_type(&binding_ty);
4045
4046                        if let Some(known_variant) =
4047                            known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
4048                        {
4049                            known_seed.insert(
4050                                var.name.clone(),
4051                                KnownVariant {
4052                                    adt: known_variant.adt,
4053                                    variant: known_variant.variant,
4054                                },
4055                            );
4056                        } else {
4057                            known_seed.remove(&var.name);
4058                        }
4059                        inferred_defs.push((var.name.clone(), preds, resolved_ty, typed_def));
4060                    }
4061
4062                    let mut env_body = env.clone();
4063                    let mut typed_bindings = Vec::with_capacity(inferred_defs.len());
4064                    for (name, preds, def_ty, typed_def) in inferred_defs {
4065                        let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
4066                        reject_ambiguous_scheme(&scheme)?;
4067                        env_body.extend(name.clone(), scheme);
4068                        typed_bindings.push((name, typed_def));
4069                    }
4070
4071                    let (p_body, t_body, typed_body) =
4072                        infer_expr(unifier, supply, &env_body, adts, &known_seed, body)?;
4073                    let typed = TypedExpr::new(
4074                        t_body.clone(),
4075                        TypedExprKind::LetRec {
4076                            bindings: typed_bindings,
4077                            body: Box::new(typed_body),
4078                        },
4079                    );
4080                    Ok((p_body, t_body, typed))
4081                }
4082                Expr::Ite(_, cond, then_expr, else_expr) => {
4083                    let (p1, t1, typed_cond) = infer_expr(unifier, supply, env, adts, known, cond)?;
4084                    unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
4085                    let (p2, t2, typed_then) =
4086                        infer_expr(unifier, supply, env, adts, known, then_expr)?;
4087                    let (p3, t3, typed_else) =
4088                        infer_expr(unifier, supply, env, adts, known, else_expr)?;
4089                    unifier.unify(&t2, &t3)?;
4090                    let out_ty = unifier.apply_type(&t2);
4091                    let mut preds = p1;
4092                    preds.extend(p2);
4093                    preds.extend(p3);
4094                    let typed = TypedExpr::new(
4095                        out_ty.clone(),
4096                        TypedExprKind::Ite {
4097                            cond: Box::new(typed_cond),
4098                            then_expr: Box::new(typed_then),
4099                            else_expr: Box::new(typed_else),
4100                        },
4101                    );
4102                    Ok((preds, out_ty, typed))
4103                }
4104                Expr::Tuple(_, elems) => {
4105                    let mut preds = Vec::new();
4106                    let mut types = Vec::new();
4107                    let mut typed_elems = Vec::new();
4108                    for elem in elems {
4109                        let (p1, t1, typed_elem) =
4110                            infer_expr(unifier, supply, env, adts, known, elem)?;
4111                        preds.extend(p1);
4112                        types.push(unifier.apply_type(&t1));
4113                        typed_elems.push(typed_elem);
4114                    }
4115                    let tuple_ty = Type::tuple(types);
4116                    let typed = TypedExpr::new(tuple_ty.clone(), TypedExprKind::Tuple(typed_elems));
4117                    Ok((preds, tuple_ty, typed))
4118                }
4119                Expr::List(_, elems) => {
4120                    let elem_tv = Type::var(supply.fresh(Some("a".into())));
4121                    let mut preds = Vec::new();
4122                    let mut typed_elems = Vec::new();
4123                    for elem in elems {
4124                        let (p1, t1, typed_elem) =
4125                            infer_expr(unifier, supply, env, adts, known, elem)?;
4126                        unifier.unify(&t1, &elem_tv)?;
4127                        preds.extend(p1);
4128                        typed_elems.push(typed_elem);
4129                    }
4130                    let list_ty = Type::app(
4131                        Type::builtin(BuiltinTypeId::List),
4132                        unifier.apply_type(&elem_tv),
4133                    );
4134                    let typed = TypedExpr::new(list_ty.clone(), TypedExprKind::List(typed_elems));
4135                    Ok((preds, list_ty, typed))
4136                }
4137                Expr::Dict(_, kvs) => {
4138                    let elem_tv = Type::var(supply.fresh(Some("v".into())));
4139                    let mut preds = Vec::new();
4140                    let mut typed_kvs = BTreeMap::new();
4141                    for (k, v) in kvs {
4142                        let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v)?;
4143                        unifier.unify(&t1, &elem_tv)?;
4144                        preds.extend(p1);
4145                        typed_kvs.insert(k.clone(), typed_v);
4146                    }
4147                    let dict_ty = Type::app(
4148                        Type::builtin(BuiltinTypeId::Dict),
4149                        unifier.apply_type(&elem_tv),
4150                    );
4151                    let typed = TypedExpr::new(dict_ty.clone(), TypedExprKind::Dict(typed_kvs));
4152                    Ok((preds, dict_ty, typed))
4153                }
4154                Expr::Match(_, scrutinee, arms) => {
4155                    let (p1, t1, typed_scrutinee) =
4156                        infer_expr(unifier, supply, env, adts, known, scrutinee)?;
4157                    let mut preds = p1;
4158                    let mut typed_arms = Vec::new();
4159                    let res_ty = Type::var(supply.fresh(Some("match".into())));
4160                    let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
4161
4162                    for (pat, expr) in arms {
4163                        let scrutinee_ty = unifier.apply_type(&t1);
4164                        let (p_pat, binds) =
4165                            infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
4166                        preds.extend(p_pat);
4167
4168                        let mut env_arm = env.clone();
4169                        for (name, ty) in binds {
4170                            env_arm
4171                                .extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
4172                        }
4173                        let mut known_arm = known.clone();
4174                        if let Expr::Var(var) = scrutinee.as_ref() {
4175                            match pat {
4176                                Pattern::Named(_, name, _) => {
4177                                    let name_sym = name.to_dotted_symbol();
4178                                    if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
4179                                        known_arm.insert(
4180                                            var.name.clone(),
4181                                            KnownVariant {
4182                                                adt: adt.name.clone(),
4183                                                variant: name_sym,
4184                                            },
4185                                        );
4186                                    } else {
4187                                        known_arm.remove(&var.name);
4188                                    }
4189                                }
4190                                _ => {
4191                                    known_arm.remove(&var.name);
4192                                }
4193                            }
4194                        }
4195                        let (p_expr, t_expr, typed_expr) =
4196                            infer_expr(unifier, supply, &env_arm, adts, &known_arm, expr)?;
4197                        unifier.unify(&res_ty, &t_expr)?;
4198                        preds.extend(p_expr);
4199                        typed_arms.push((pat.clone(), typed_expr));
4200                    }
4201
4202                    let scrutinee_ty = unifier.apply_type(&t1);
4203                    check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
4204                    let out_ty = unifier.apply_type(&res_ty);
4205                    let typed = TypedExpr::new(
4206                        out_ty.clone(),
4207                        TypedExprKind::Match {
4208                            scrutinee: Box::new(typed_scrutinee),
4209                            arms: typed_arms,
4210                        },
4211                    );
4212                    Ok((preds, out_ty, typed))
4213                }
4214                Expr::Ann(_, expr, ann) => {
4215                    let ann_ty = type_from_annotation_expr(adts, ann)?;
4216                    match expr.as_ref() {
4217                        Expr::RecordUpdate(_, base, updates) => {
4218                            infer_record_update_typed_with_hint(
4219                                unifier,
4220                                supply,
4221                                env,
4222                                adts,
4223                                known,
4224                                base.as_ref(),
4225                                updates,
4226                                &ann_ty,
4227                            )
4228                        }
4229                        _ => {
4230                            let (preds, expr_ty, typed_expr) =
4231                                infer_expr(unifier, supply, env, adts, known, expr)?;
4232                            unifier.unify(&expr_ty, &ann_ty)?;
4233                            let out_ty = unifier.apply_type(&ann_ty);
4234                            Ok((preds, out_ty, typed_expr))
4235                        }
4236                    }
4237                }
4238            }
4239        })()
4240    });
4241    res.map_err(|err| with_span(&span, err))
4242}
4243
4244fn ctor_lookup<'a>(
4245    adts: &'a HashMap<Symbol, AdtDecl>,
4246    name: &Symbol,
4247) -> Option<(&'a AdtDecl, &'a AdtVariant)> {
4248    let mut found = None;
4249    for adt in adts.values() {
4250        if let Some(variant) = adt.variants.iter().find(|v| &v.name == name) {
4251            if found.is_some() {
4252                return None;
4253            }
4254            found = Some((adt, variant));
4255        }
4256    }
4257    found
4258}
4259
4260fn record_fields(variant: &AdtVariant) -> Option<&[(Symbol, Type)]> {
4261    if variant.args.len() != 1 {
4262        return None;
4263    }
4264    match variant.args[0].as_ref() {
4265        TypeKind::Record(fields) => Some(fields),
4266        _ => None,
4267    }
4268}
4269
4270fn instantiate_variant_fields(
4271    adt: &AdtDecl,
4272    variant: &AdtVariant,
4273    supply: &mut TypeVarSupply,
4274) -> Option<(Type, Vec<(Symbol, Type)>)> {
4275    let fields = record_fields(variant)?;
4276    let mut subst = Subst::new_sync();
4277    for param in &adt.params {
4278        let fresh = Type::var(supply.fresh(param.var.name.clone()));
4279        subst = subst.insert(param.var.id, fresh);
4280    }
4281    let result_ty = adt.result_type().apply(&subst);
4282    let fields = fields
4283        .iter()
4284        .map(|(name, ty)| (name.clone(), ty.apply(&subst)))
4285        .collect();
4286    Some((result_ty, fields))
4287}
4288
4289fn known_variant_from_expr(
4290    expr: &Expr,
4291    expr_ty: &Type,
4292    adts: &HashMap<Symbol, AdtDecl>,
4293) -> Option<KnownVariant> {
4294    let mut expr = expr;
4295    while let Expr::Ann(_, inner, _) = expr {
4296        expr = inner.as_ref();
4297    }
4298    if matches!(expr_ty.as_ref(), TypeKind::Fun(..)) {
4299        return None;
4300    }
4301    let ctor = match expr {
4302        Expr::App(_, f, _) => match f.as_ref() {
4303            Expr::Var(var) => var.name.clone(),
4304            _ => return None,
4305        },
4306        _ => return None,
4307    };
4308    let (adt, variant) = ctor_lookup(adts, &ctor)?;
4309    record_fields(variant)?;
4310    Some(KnownVariant {
4311        adt: adt.name.clone(),
4312        variant: variant.name.clone(),
4313    })
4314}
4315
4316fn known_variant_from_expr_with_known(
4317    expr: &Expr,
4318    expr_ty: &Type,
4319    adts: &HashMap<Symbol, AdtDecl>,
4320    known: &KnownVariants,
4321) -> Option<KnownVariant> {
4322    let mut expr = expr;
4323    while let Expr::Ann(_, inner, _) = expr {
4324        expr = inner.as_ref();
4325    }
4326    match expr {
4327        Expr::Var(var) => known.get(&var.name).cloned(),
4328        Expr::RecordUpdate(_, base, _) => {
4329            known_variant_from_expr_with_known(base.as_ref(), expr_ty, adts, known)
4330        }
4331        _ => known_variant_from_expr(expr, expr_ty, adts),
4332    }
4333}
4334
4335fn select_record_variant<'a, F>(
4336    adts: &'a HashMap<Symbol, AdtDecl>,
4337    base_ty: &Type,
4338    known_variant: Option<KnownVariant>,
4339    field_for_errors: &Symbol,
4340    matches_fields: F,
4341) -> Result<(&'a AdtDecl, &'a AdtVariant), TypeError>
4342where
4343    F: Fn(&[(Symbol, Type)]) -> bool,
4344{
4345    if let Some(info) = known_variant {
4346        let adt = adts
4347            .get(&info.adt)
4348            .ok_or_else(|| TypeError::UnknownTypeName(info.adt.clone()))?;
4349        let variant = adt
4350            .variants
4351            .iter()
4352            .find(|v| v.name == info.variant)
4353            .ok_or_else(|| TypeError::UnknownField {
4354                field: field_for_errors.clone(),
4355                typ: base_ty.to_string(),
4356            })?;
4357        return Ok((adt, variant));
4358    }
4359
4360    if let Some(adt_name) = type_head_name(base_ty) {
4361        let adt = adts.get(adt_name).ok_or_else(|| TypeError::UnknownField {
4362            field: field_for_errors.clone(),
4363            typ: base_ty.to_string(),
4364        })?;
4365        if adt.variants.len() == 1 {
4366            return Ok((adt, &adt.variants[0]));
4367        }
4368        return Err(TypeError::FieldNotKnown {
4369            field: field_for_errors.clone(),
4370            typ: base_ty.to_string(),
4371        });
4372    }
4373
4374    if matches!(base_ty.as_ref(), TypeKind::Var(_)) {
4375        let mut candidates = Vec::new();
4376        for adt in adts.values() {
4377            if adt.variants.len() != 1 {
4378                continue;
4379            }
4380            let variant = &adt.variants[0];
4381            let Some(fields) = record_fields(variant) else {
4382                continue;
4383            };
4384            if matches_fields(fields) {
4385                candidates.push((adt, variant));
4386            }
4387        }
4388        if candidates.len() == 1 {
4389            return Ok(candidates.remove(0));
4390        }
4391        if candidates.is_empty() {
4392            return Err(TypeError::UnknownField {
4393                field: field_for_errors.clone(),
4394                typ: base_ty.to_string(),
4395            });
4396        }
4397        return Err(TypeError::FieldNotKnown {
4398            field: field_for_errors.clone(),
4399            typ: base_ty.to_string(),
4400        });
4401    }
4402
4403    Err(TypeError::UnknownField {
4404        field: field_for_errors.clone(),
4405        typ: base_ty.to_string(),
4406    })
4407}
4408
4409fn resolve_record_update(
4410    unifier: &mut Unifier<'_>,
4411    supply: &mut TypeVarSupply,
4412    adts: &HashMap<Symbol, AdtDecl>,
4413    base_ty: &Type,
4414    known_variant: Option<KnownVariant>,
4415    update_fields: &[Symbol],
4416) -> Result<(Type, Vec<(Symbol, Type)>), TypeError> {
4417    if let TypeKind::Record(fields) = base_ty.as_ref() {
4418        return Ok((base_ty.clone(), fields.clone()));
4419    }
4420
4421    let field_for_errors = update_fields.first().cloned().unwrap_or_else(|| sym("_"));
4422
4423    let (adt, variant) =
4424        select_record_variant(adts, base_ty, known_variant, &field_for_errors, |fields| {
4425            update_fields
4426                .iter()
4427                .all(|field| fields.iter().any(|(name, _)| name == field))
4428        })?;
4429
4430    let (result_ty, fields) =
4431        instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
4432            TypeError::UnknownField {
4433                field: field_for_errors.clone(),
4434                typ: base_ty.to_string(),
4435            }
4436        })?;
4437
4438    for field in update_fields {
4439        if fields.iter().all(|(name, _)| name != field) {
4440            return Err(TypeError::UnknownField {
4441                field: field.clone(),
4442                typ: base_ty.to_string(),
4443            });
4444        }
4445    }
4446
4447    unifier.unify(base_ty, &result_ty)?;
4448    let result_ty = unifier.apply_type(&result_ty);
4449    let fields = fields
4450        .into_iter()
4451        .map(|(name, ty)| (name, unifier.apply_type(&ty)))
4452        .collect();
4453    Ok((result_ty, fields))
4454}
4455
4456fn resolve_projection(
4457    unifier: &mut Unifier<'_>,
4458    supply: &mut TypeVarSupply,
4459    adts: &HashMap<Symbol, AdtDecl>,
4460    base_ty: &Type,
4461    known_variant: Option<KnownVariant>,
4462    field: &Symbol,
4463) -> Result<Type, TypeError> {
4464    if let Ok(index) = field.as_ref().parse::<usize>() {
4465        let elem_ty = match base_ty.as_ref() {
4466            TypeKind::Tuple(elems) => {
4467                elems
4468                    .get(index)
4469                    .cloned()
4470                    .ok_or_else(|| TypeError::UnknownField {
4471                        field: field.clone(),
4472                        typ: base_ty.to_string(),
4473                    })?
4474            }
4475            TypeKind::Var(_) => {
4476                let mut elems = Vec::with_capacity(index + 1);
4477                for _ in 0..=index {
4478                    elems.push(Type::var(supply.fresh(Some(sym("t")))));
4479                }
4480                let tuple_ty = Type::tuple(elems.clone());
4481                unifier.unify(base_ty, &tuple_ty)?;
4482                elems[index].clone()
4483            }
4484            _ => {
4485                return Err(TypeError::UnknownField {
4486                    field: field.clone(),
4487                    typ: base_ty.to_string(),
4488                });
4489            }
4490        };
4491        return Ok(unifier.apply_type(&elem_ty));
4492    }
4493
4494    let (adt, variant) = select_record_variant(adts, base_ty, known_variant, field, |fields| {
4495        fields.iter().any(|(name, _)| name == field)
4496    })?;
4497
4498    let (result_ty, fields) =
4499        instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
4500            TypeError::UnknownField {
4501                field: field.clone(),
4502                typ: base_ty.to_string(),
4503            }
4504        })?;
4505    let field_ty = fields
4506        .iter()
4507        .find(|(name, _)| name == field)
4508        .map(|(_, ty)| ty.clone())
4509        .ok_or_else(|| TypeError::UnknownField {
4510            field: field.clone(),
4511            typ: base_ty.to_string(),
4512        })?;
4513    unifier.unify(base_ty, &result_ty)?;
4514    Ok(unifier.apply_type(&field_ty))
4515}
4516
4517fn decompose_fun(typ: &Type, arity: usize) -> Option<(Vec<Type>, Type)> {
4518    let mut args = Vec::with_capacity(arity);
4519    let mut cur = typ.clone();
4520    for _ in 0..arity {
4521        match cur.as_ref() {
4522            TypeKind::Fun(a, b) => {
4523                args.push(a.clone());
4524                cur = b.clone();
4525            }
4526            _ => return None,
4527        }
4528    }
4529    Some((args, cur))
4530}
4531
/// Output of `infer_pattern`: the predicates produced while instantiating
/// constructor schemes, and every `(variable, type)` binding the pattern
/// introduces.
type InferPatternResult = (Vec<Predicate>, Vec<(Symbol, Type)>);

/// Infers what matching `pat` against a value of type `scrutinee_ty` implies.
///
/// `scrutinee_ty` is unified with the shape the pattern requires:
/// - a constructor result type;
/// - `List a`;
/// - a tuple of the right arity;
/// - a record's fields or `Dict v`.
///
/// The function returns the predicates collected from constructor schemes
/// together with the variables the pattern binds. The types of those bindings
/// have the unifier's substitution applied.
///
/// Each call first invokes `unifier.charge_infer_node()`; an error from it
/// aborts inference. Any error is passed through `with_span` with the
/// pattern's span.
///
/// # Errors
///
/// - `UnknownVar` for an unbound constructor name.
/// - `AmbiguousOverload` when the constructor does not resolve to exactly one
///   scheme.
/// - `UnsupportedExpr("pattern constructor")` when the constructor's type has
///   fewer arrows than the pattern has arguments.
/// - `UnknownField` when a dict pattern names a field missing from a record
///   scrutinee.
/// - Any unification error.
fn infer_pattern(
    unifier: &mut Unifier<'_>,
    supply: &mut TypeVarSupply,
    env: &TypeEnv,
    pat: &Pattern,
    scrutinee_ty: &Type,
) -> Result<InferPatternResult, TypeError> {
    let span = *pat.span();
    let res = (|| {
        unifier.charge_infer_node()?;
        match pat {
            Pattern::Wildcard(..) => Ok((vec![], vec![])),
            Pattern::Var(var) => Ok((
                vec![],
                vec![(var.name.clone(), unifier.apply_type(scrutinee_ty))],
            )),
            Pattern::Named(_, name, ps) => {
                // Resolve the constructor to a single scheme and instantiate it
                // with fresh type variables.
                let ctor_name = name.to_dotted_symbol();
                let schemes = env
                    .lookup(&ctor_name)
                    .ok_or_else(|| TypeError::UnknownVar(ctor_name.clone()))?;
                if schemes.len() != 1 {
                    return Err(TypeError::AmbiguousOverload(ctor_name));
                }
                let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
                let (preds, ctor_ty) = instantiate(&scheme, supply);
                // Take one argument type per sub-pattern. What remains must be
                // the scrutinee's type. NOTE(review): with fewer sub-patterns than
                // the constructor takes, `res_ty` is still a function type and the
                // unify below reports the mismatch.
                let (arg_tys, res_ty) = decompose_fun(&ctor_ty, ps.len())
                    .ok_or(TypeError::UnsupportedExpr("pattern constructor"))?;
                unifier.unify(&res_ty, scrutinee_ty)?;
                let mut all_preds = preds;
                let mut bindings = Vec::new();
                for (p, arg_ty) in ps.iter().zip(arg_tys.iter()) {
                    let arg_ty = unifier.apply_type(arg_ty);
                    let (p1, binds1) = infer_pattern(unifier, supply, env, p, &arg_ty)?;
                    all_preds.extend(p1);
                    bindings.extend(binds1);
                }
                // Re-apply so earlier bindings see unifications made by later
                // sub-patterns.
                let bindings = bindings
                    .into_iter()
                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                    .collect();
                Ok((all_preds, bindings))
            }
            Pattern::List(_, ps) => {
                // Every element pattern is checked against the shared element type.
                let elem_tv = Type::var(supply.fresh(Some("a".into())));
                let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
                unifier.unify(scrutinee_ty, &list_ty)?;
                let mut preds = Vec::new();
                let mut bindings = Vec::new();
                for p in ps {
                    let elem_ty = unifier.apply_type(&elem_tv);
                    let (p1, binds1) = infer_pattern(unifier, supply, env, p, &elem_ty)?;
                    preds.extend(p1);
                    bindings.extend(binds1);
                }
                let bindings = bindings
                    .into_iter()
                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                    .collect();
                Ok((preds, bindings))
            }
            Pattern::Cons(_, head, tail) => {
                // `head` matches one element; `tail` matches the rest as a `List`.
                let elem_tv = Type::var(supply.fresh(Some("a".into())));
                let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
                unifier.unify(scrutinee_ty, &list_ty)?;
                let mut preds = Vec::new();
                let mut bindings = Vec::new();

                let head_ty = unifier.apply_type(&elem_tv);
                let (p1, binds1) = infer_pattern(unifier, supply, env, head, &head_ty)?;
                preds.extend(p1);
                bindings.extend(binds1);

                let tail_ty = Type::app(
                    Type::builtin(BuiltinTypeId::List),
                    unifier.apply_type(&elem_tv),
                );
                let (p2, binds2) = infer_pattern(unifier, supply, env, tail, &tail_ty)?;
                preds.extend(p2);
                bindings.extend(binds2);

                let bindings = bindings
                    .into_iter()
                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                    .collect();
                Ok((preds, bindings))
            }
            Pattern::Tuple(_, elems) => {
                // Unify against a tuple type of the right arity.
                let mut elem_tys: Vec<Type> = (0..elems.len())
                    .map(|i| Type::var(supply.fresh(Some(format!("t{i}").into()))))
                    .collect();
                let expected = Type::tuple(elem_tys.clone());
                unifier.unify(scrutinee_ty, &expected)?;
                elem_tys = elem_tys
                    .into_iter()
                    .map(|t| unifier.apply_type(&t))
                    .collect();

                let mut preds = Vec::new();
                let mut bindings = Vec::new();
                for (p, ty) in elems.iter().zip(elem_tys.iter()) {
                    let (p_preds, p_binds) = infer_pattern(unifier, supply, env, p, ty)?;
                    preds.extend(p_preds);
                    bindings.extend(p_binds);
                }
                let bindings = bindings
                    .into_iter()
                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                    .collect();
                Ok((preds, bindings))
            }
            Pattern::Dict(_, fields) => {
                if let TypeKind::Record(ty_fields) = scrutinee_ty.as_ref() {
                    // The scrutinee is already a record: each key must name one of
                    // its fields, and the sub-pattern gets that field's type.
                    let mut preds = Vec::new();
                    let mut bindings = Vec::new();
                    for (key, pat) in fields {
                        let ty = ty_fields
                            .iter()
                            .find(|(name, _)| name == key)
                            .map(|(_, ty)| unifier.apply_type(ty))
                            .ok_or_else(|| TypeError::UnknownField {
                                field: key.clone(),
                                typ: scrutinee_ty.to_string(),
                            })?;
                        let (p_preds, p_binds) = infer_pattern(unifier, supply, env, pat, &ty)?;
                        preds.extend(p_preds);
                        bindings.extend(p_binds);
                    }
                    let bindings = bindings
                        .into_iter()
                        .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                        .collect();
                    Ok((preds, bindings))
                } else {
                    // Otherwise the scrutinee is treated as `Dict v`. Keys are not
                    // type-checked here, and every value pattern shares `v`.
                    let elem_tv = Type::var(supply.fresh(Some("v".into())));
                    let dict_ty = Type::app(Type::builtin(BuiltinTypeId::Dict), elem_tv.clone());
                    unifier.unify(scrutinee_ty, &dict_ty)?;
                    let elem_ty = unifier.apply_type(&elem_tv);

                    let mut preds = Vec::new();
                    let mut bindings = Vec::new();
                    for (_key, pat) in fields {
                        let (p_preds, p_binds) =
                            infer_pattern(unifier, supply, env, pat, &elem_ty)?;
                        preds.extend(p_preds);
                        bindings.extend(p_binds);
                    }
                    let bindings = bindings
                        .into_iter()
                        .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                        .collect();
                    Ok((preds, bindings))
                }
            }
        }
    })();
    res.map_err(|err| with_span(&span, err))
}
4693
4694fn type_head_name(typ: &Type) -> Option<&Symbol> {
4695    let mut cur = typ;
4696    while let TypeKind::App(head, _) = cur.as_ref() {
4697        cur = head;
4698    }
4699    match cur.as_ref() {
4700        TypeKind::Con(tc) => Some(&tc.name),
4701        _ => None,
4702    }
4703}
4704
/// A constructor name that [`collect_adts_in_types`] reports as conflicting,
/// together with the definitions seen under that name.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AdtConflict {
    /// The constructor name shared by the conflicting definitions.
    pub name: Symbol,
    /// The distinct constructor-head types recorded for `name`.
    pub definitions: Vec<Type>,
}
4710
/// Error returned by [`collect_adts_in_types`] when constructor names appear
/// with incompatible definitions.
#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)]
#[error("conflicting ADT definitions: {conflicts:?}")]
pub struct CollectAdtsError {
    /// One entry per conflicting constructor name.
    pub conflicts: Vec<AdtConflict>,
}
4716
4717/// Collect all user-defined ADT constructors referenced by the provided types.
4718///
4719/// This walks each type recursively (including nested occurrences), returns a
4720/// deduplicated list of constructor heads, and rejects ambiguous constructor
4721/// names that appear with incompatible definitions.
4722///
4723/// The returned `Type`s are constructor heads (for example `Foo`), suitable
4724/// for passing to embedder utilities that derive `AdtDecl`s from type
4725/// constructors.
4726///
4727/// # Examples
4728///
4729/// ```rust,ignore
4730/// use rex_ts::{collect_adts_in_types, BuiltinTypeId, Type};
4731///
4732/// let types = vec![
4733///     Type::app(Type::user_con("Foo", 1), Type::builtin(BuiltinTypeId::I32)),
4734///     Type::fun(Type::user_con("Bar", 0), Type::user_con("Foo", 1)),
4735/// ];
4736///
4737/// let adts = collect_adts_in_types(types).unwrap();
4738/// assert_eq!(adts, vec![Type::user_con("Foo", 1), Type::user_con("Bar", 0)]);
4739/// ```
4740///
4741/// ```rust,ignore
4742/// use rex_ts::{collect_adts_in_types, Type};
4743///
4744/// let err = collect_adts_in_types(vec![
4745///     Type::user_con("Thing", 1),
4746///     Type::user_con("Thing", 2),
4747/// ])
4748/// .unwrap_err();
4749///
4750/// assert_eq!(err.conflicts.len(), 1);
4751/// assert_eq!(err.conflicts[0].name.as_ref(), "Thing");
4752/// ```
4753pub fn collect_adts_in_types(types: Vec<Type>) -> Result<Vec<Type>, CollectAdtsError> {
4754    fn visit(
4755        typ: &Type,
4756        out: &mut Vec<Type>,
4757        seen: &mut HashSet<Type>,
4758        defs_by_name: &mut BTreeMap<Symbol, Vec<Type>>,
4759    ) {
4760        match typ.as_ref() {
4761            TypeKind::Var(_) => {}
4762            TypeKind::Con(tc) => {
4763                // Builtins are not embeddable ADT declarations.
4764                if tc.builtin_id.is_none() {
4765                    let adt = Type::new(TypeKind::Con(tc.clone()));
4766                    if seen.insert(adt.clone()) {
4767                        out.push(adt.clone());
4768                    }
4769                    let defs = defs_by_name.entry(tc.name.clone()).or_default();
4770                    if !defs.contains(&adt) {
4771                        defs.push(adt);
4772                    }
4773                }
4774            }
4775            TypeKind::App(fun, arg) => {
4776                visit(fun, out, seen, defs_by_name);
4777                visit(arg, out, seen, defs_by_name);
4778            }
4779            TypeKind::Fun(arg, ret) => {
4780                visit(arg, out, seen, defs_by_name);
4781                visit(ret, out, seen, defs_by_name);
4782            }
4783            TypeKind::Tuple(elems) => {
4784                for elem in elems {
4785                    visit(elem, out, seen, defs_by_name);
4786                }
4787            }
4788            TypeKind::Record(fields) => {
4789                for (_name, field_ty) in fields {
4790                    visit(field_ty, out, seen, defs_by_name);
4791                }
4792            }
4793        }
4794    }
4795
4796    let mut out = Vec::new();
4797    let mut seen = HashSet::new();
4798    let mut defs_by_name: BTreeMap<Symbol, Vec<Type>> = BTreeMap::new();
4799    for typ in &types {
4800        visit(typ, &mut out, &mut seen, &mut defs_by_name);
4801    }
4802
4803    let conflicts: Vec<AdtConflict> = defs_by_name
4804        .into_iter()
4805        .filter_map(|(name, definitions)| {
4806            (definitions.len() > 1).then_some(AdtConflict { name, definitions })
4807        })
4808        .collect();
4809    if !conflicts.is_empty() {
4810        return Err(CollectAdtsError { conflicts });
4811    }
4812
4813    Ok(out)
4814}
4815
4816fn adt_name_from_patterns(adts: &HashMap<Symbol, AdtDecl>, patterns: &[Pattern]) -> Option<Symbol> {
4817    let mut candidate: Option<Symbol> = None;
4818    for pat in patterns {
4819        let next = match pat {
4820            Pattern::Named(_, name, _) => {
4821                let name_sym = name.to_dotted_symbol();
4822                ctor_lookup(adts, &name_sym).map(|(adt, _)| adt.name.clone())
4823            }
4824            Pattern::List(..) | Pattern::Cons(..) => Some(sym("List")),
4825            _ => None,
4826        };
4827        if let Some(next) = next {
4828            match &candidate {
4829                None => candidate = Some(next),
4830                Some(prev) if *prev == next => {}
4831                Some(_) => return None,
4832            }
4833        }
4834    }
4835    candidate
4836}
4837
4838fn check_match_exhaustive(
4839    adts: &HashMap<Symbol, AdtDecl>,
4840    scrutinee_ty: &Type,
4841    patterns: &[Pattern],
4842) -> Result<(), TypeError> {
4843    if patterns
4844        .iter()
4845        .any(|p| matches!(p, Pattern::Wildcard(..) | Pattern::Var(_)))
4846    {
4847        return Ok(());
4848    }
4849    let adt_name = match type_head_name(scrutinee_ty).cloned() {
4850        Some(name) => name,
4851        None => match adt_name_from_patterns(adts, patterns) {
4852            Some(name) => name,
4853            None => return Ok(()),
4854        },
4855    };
4856    let adt = match adts.get(&adt_name) {
4857        Some(adt) => adt,
4858        None => return Ok(()),
4859    };
4860    let ctor_names: HashSet<Symbol> = adt.variants.iter().map(|v| v.name.clone()).collect();
4861    if ctor_names.is_empty() {
4862        return Ok(());
4863    }
4864    let mut covered = HashSet::new();
4865    for pat in patterns {
4866        match pat {
4867            Pattern::Named(_, name, _) => {
4868                let name_sym = name.to_dotted_symbol();
4869                if ctor_names.contains(&name_sym) {
4870                    covered.insert(name_sym);
4871                }
4872            }
4873            Pattern::List(_, elems) if adt_name.as_ref() == "List" && elems.is_empty() => {
4874                covered.insert(sym("Empty"));
4875            }
4876            Pattern::Cons(..) if adt_name.as_ref() == "List" => {
4877                covered.insert(sym("Cons"));
4878            }
4879            _ => {}
4880        }
4881    }
4882    let mut missing: Vec<Symbol> = ctor_names.difference(&covered).cloned().collect();
4883    if missing.is_empty() {
4884        return Ok(());
4885    }
4886    missing.sort();
4887    Err(TypeError::NonExhaustiveMatch {
4888        typ: scrutinee_ty.to_string(),
4889        missing,
4890    })
4891}
4892
4893#[cfg(test)]
4894mod tests {
4895    use super::*;
4896    use rexlang_lexer::Token;
4897    use rexlang_parser::Parser;
4898    use rexlang_util::{GasCosts, GasMeter};
4899
4900    fn tvar(id: TypeVarId, name: &str) -> Type {
4901        Type::var(TypeVar::new(id, Some(sym(name))))
4902    }
4903
4904    fn dict_of(elem: Type) -> Type {
4905        Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
4906    }
4907
4908    #[test]
4909    fn unify_simple() {
4910        let t1 = Type::fun(tvar(0, "a"), Type::builtin(BuiltinTypeId::U32));
4911        let t2 = Type::fun(Type::builtin(BuiltinTypeId::U16), tvar(1, "b"));
4912        let subst = unify(&t1, &t2).unwrap();
4913        assert_eq!(subst.get(&0), Some(&Type::builtin(BuiltinTypeId::U16)));
4914        assert_eq!(subst.get(&1), Some(&Type::builtin(BuiltinTypeId::U32)));
4915    }
4916
4917    #[test]
4918    fn occurs_check_blocks_infinite_type() {
4919        let tv = TypeVar::new(0, Some(sym("a")));
4920        let t = Type::fun(Type::var(tv.clone()), Type::builtin(BuiltinTypeId::U8));
4921        let err = bind(&tv, &t).unwrap_err();
4922        assert!(matches!(err, TypeError::Occurs(_, _)));
4923    }
4924
4925    #[test]
4926    fn instantiate_and_generalize_round_trip() {
4927        let mut supply = TypeVarSupply::new();
4928        let a = Type::var(supply.fresh(Some(sym("a"))));
4929        let scheme = generalize(&TypeEnv::new(), vec![], Type::fun(a.clone(), a.clone()));
4930        let (preds, inst) = instantiate(&scheme, &mut supply);
4931        assert!(preds.is_empty());
4932        if let TypeKind::Fun(l, r) = inst.as_ref() {
4933            match (l.as_ref(), r.as_ref()) {
4934                (TypeKind::Var(_), TypeKind::Var(_)) => {}
4935                _ => panic!("expected polymorphic identity"),
4936            }
4937        } else {
4938            panic!("expected function type");
4939        }
4940    }
4941
4942    #[test]
4943    fn entail_superclasses() {
4944        let ts = TypeSystem::with_prelude().unwrap();
4945        let pred = Predicate::new("Semiring", Type::builtin(BuiltinTypeId::I32));
4946        let given = [Predicate::new(
4947            "AdditiveGroup",
4948            Type::builtin(BuiltinTypeId::I32),
4949        )];
4950        assert!(entails(&ts.classes, &given, &pred).unwrap());
4951    }
4952
4953    #[test]
4954    fn entail_instances() {
4955        let ts = TypeSystem::with_prelude().unwrap();
4956        let pred = Predicate::new("Field", Type::builtin(BuiltinTypeId::F32));
4957        assert!(entails(&ts.classes, &[], &pred).unwrap());
4958
4959        let pred_fail = Predicate::new("Field", Type::builtin(BuiltinTypeId::U32));
4960        assert!(!entails(&ts.classes, &[], &pred_fail).unwrap());
4961    }
4962
4963    #[test]
4964    fn prelude_injects_functions() {
4965        let ts = TypeSystem::with_prelude().unwrap();
4966        let minus = ts.env.lookup(&sym("-")).expect("minus in env");
4967        let div = ts.env.lookup(&sym("/")).expect("div in env");
4968        assert_eq!(minus.len(), 1);
4969        assert_eq!(div.len(), 1);
4970        let minus = &minus[0];
4971        let div = &div[0];
4972        assert_eq!(minus.preds.len(), 1);
4973        assert_eq!(minus.vars.len(), 1);
4974        assert_eq!(div.preds.len(), 1);
4975        assert_eq!(div.vars.len(), 1);
4976    }
4977
4978    #[test]
4979    fn adt_constructors_are_present() {
4980        let ts = TypeSystem::with_prelude().unwrap();
4981        assert!(ts.env.lookup(&sym("Empty")).is_some());
4982        assert!(ts.env.lookup(&sym("Cons")).is_some());
4983        assert!(ts.env.lookup(&sym("Ok")).is_some());
4984        assert!(ts.env.lookup(&sym("Err")).is_some());
4985        assert!(ts.env.lookup(&sym("Some")).is_some());
4986        assert!(ts.env.lookup(&sym("None")).is_some());
4987    }
4988
4989    fn parse_expr(code: &str) -> std::sync::Arc<rexlang_ast::expr::Expr> {
4990        let mut parser = Parser::new(Token::tokenize(code).unwrap());
4991        parser.parse_program(&mut GasMeter::default()).unwrap().expr
4992    }
4993
4994    fn parse_program(code: &str) -> rexlang_ast::expr::Program {
4995        let mut parser = Parser::new(Token::tokenize(code).unwrap());
4996        parser.parse_program(&mut GasMeter::default()).unwrap()
4997    }
4998
4999    #[test]
5000    fn infer_deep_list_does_not_overflow() {
5001        // Regression test: moderately deep right-nested terms should infer on default limits.
5002        const N: usize = 40;
5003        let mut code = String::new();
5004        code.push_str("let xs = ");
5005        for _ in 0..N {
5006            code.push_str("Cons 0 (");
5007        }
5008        code.push_str("Empty");
5009        for _ in 0..N {
5010            code.push(')');
5011        }
5012        code.push_str(" in xs");
5013
5014        let parse_handle = std::thread::Builder::new()
5015            .name("infer_deep_list_parse".into())
5016            .stack_size(128 * 1024 * 1024)
5017            .spawn(move || {
5018                let tokens = Token::tokenize(&code).unwrap();
5019                let mut parser = Parser::new(tokens);
5020                parser.parse_program(&mut GasMeter::default())
5021            })
5022            .unwrap();
5023        let program = parse_handle.join().unwrap().unwrap();
5024        let expr = program.expr;
5025        let mut ts = TypeSystem::with_prelude().unwrap();
5026        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5027        assert_eq!(
5028            ty,
5029            Type::app(
5030                Type::builtin(BuiltinTypeId::List),
5031                Type::builtin(BuiltinTypeId::I32)
5032            )
5033        );
5034    }
5035
5036    #[test]
5037    fn collect_adts_in_types_finds_nested_unique_adts() {
5038        let foo = Type::user_con("Foo", 1);
5039        let bar = Type::user_con("Bar", 0);
5040        let ty = Type::fun(
5041            Type::app(
5042                Type::builtin(BuiltinTypeId::List),
5043                Type::app(foo.clone(), tvar(0, "a")),
5044            ),
5045            Type::tuple(vec![
5046                Type::app(foo.clone(), Type::builtin(BuiltinTypeId::I32)),
5047                bar.clone(),
5048            ]),
5049        );
5050
5051        let adts = collect_adts_in_types(vec![ty]).unwrap();
5052        assert_eq!(adts, vec![foo, bar]);
5053    }
5054
5055    #[test]
5056    fn collect_adts_in_types_rejects_conflicting_names() {
5057        let arity1 = Type::user_con("Thing", 1);
5058        let arity2 = Type::user_con("Thing", 2);
5059
5060        let err = collect_adts_in_types(vec![arity1.clone(), arity2.clone()]).unwrap_err();
5061        assert_eq!(err.conflicts.len(), 1);
5062        let conflict = &err.conflicts[0];
5063        assert_eq!(conflict.name, sym("Thing"));
5064        assert_eq!(conflict.definitions, vec![arity1, arity2]);
5065    }
5066
5067    #[test]
5068    fn infer_depth_limit_is_enforced() {
5069        const N: usize = 40;
5070        let mut code = String::new();
5071        code.push_str("let xs = ");
5072        for _ in 0..N {
5073            code.push_str("Cons 0 (");
5074        }
5075        code.push_str("Empty");
5076        for _ in 0..N {
5077            code.push(')');
5078        }
5079        code.push_str(" in xs");
5080
5081        let program = parse_program(&code);
5082        let mut ts = TypeSystem::with_prelude().unwrap();
5083        ts.set_limits(TypeSystemLimits {
5084            max_infer_depth: Some(8),
5085        });
5086
5087        let err = ts.infer(program.expr.as_ref()).unwrap_err();
5088        assert!(
5089            err.to_string().contains("maximum inference depth exceeded"),
5090            "expected a max-depth inference error, got: {err:?}"
5091        );
5092    }
5093
5094    #[test]
5095    fn declare_fn_injects_scheme_for_use_sites() {
5096        let program = parse_program(
5097            r#"
5098            declare fn id x: a -> a
5099            id 1
5100            "#,
5101        );
5102        let mut ts = TypeSystem::with_prelude().unwrap();
5103        ts.inject_decls(&program.decls).unwrap();
5104        let (preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5105        assert!(
5106            preds.is_empty()
5107                || preds.iter().all(|p| p.class.as_ref() == "Integral"
5108                    && p.typ == Type::builtin(BuiltinTypeId::I32))
5109        );
5110        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5111    }
5112
5113    #[test]
5114    fn declare_fn_is_noop_when_matching_existing_scheme() {
5115        let mut ts = TypeSystem::with_prelude().unwrap();
5116        ts.add_value(
5117            "foo",
5118            Scheme::new(
5119                vec![],
5120                vec![],
5121                Type::fun(
5122                    Type::builtin(BuiltinTypeId::I32),
5123                    Type::builtin(BuiltinTypeId::I32),
5124                ),
5125            ),
5126        );
5127
5128        let program = parse_program(
5129            r#"
5130            declare fn foo x: i32 -> i32
5131            0
5132            "#,
5133        );
5134        let rexlang_ast::expr::Decl::DeclareFn(fd) = &program.decls[0] else {
5135            panic!("expected declare fn decl");
5136        };
5137        ts.inject_declare_fn_decl(fd).unwrap();
5138    }
5139
5140    #[test]
5141    fn unit_type_parses_and_infers() {
5142        let program = parse_program(
5143            r#"
5144            fn unit_id x: () -> () = x
5145            unit_id ()
5146            "#,
5147        );
5148        let mut ts = TypeSystem::with_prelude().unwrap();
5149        ts.inject_decls(&program.decls).unwrap();
5150        let (preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5151        assert!(preds.is_empty());
5152        assert_eq!(ty, Type::tuple(vec![]));
5153    }
5154
5155    fn strip_span(mut err: TypeError) -> TypeError {
5156        while let TypeError::Spanned { error, .. } = err {
5157            err = *error;
5158        }
5159        err
5160    }
5161
5162    #[test]
5163    fn type_errors_include_span() {
5164        let expr = parse_expr("missing");
5165        let mut ts = TypeSystem::with_prelude().unwrap();
5166        let err = ts.infer(expr.as_ref()).unwrap_err();
5167        match err {
5168            TypeError::Spanned { span, error } => {
5169                assert_ne!(span, Span::default());
5170                assert!(matches!(
5171                    *error,
5172                    TypeError::UnknownVar(name) if name.as_ref() == "missing"
5173                ));
5174            }
5175            other => panic!("expected spanned error, got {other:?}"),
5176        }
5177    }
5178
5179    #[test]
5180    fn infer_with_gas_rejects_out_of_budget() {
5181        let expr = parse_expr("1");
5182        let mut ts = TypeSystem::with_prelude().unwrap();
5183        let mut gas = GasMeter::new(
5184            Some(0),
5185            GasCosts {
5186                infer_node: 1,
5187                unify_step: 0,
5188                ..GasCosts::sensible_defaults()
5189            },
5190        );
5191        let err = ts.infer_with_gas(expr.as_ref(), &mut gas).unwrap_err();
5192        assert!(matches!(strip_span(err), TypeError::OutOfGas(..)));
5193    }
5194
5195    #[test]
5196    fn reject_user_redefinition_of_primitive_type_name() {
5197        let program = parse_program("type i32 = I32Wrap i32");
5198        let mut ts = TypeSystem::with_prelude().unwrap();
5199        let rexlang_ast::expr::Decl::Type(decl) = &program.decls[0] else {
5200            panic!("expected type decl");
5201        };
5202        let err = ts.inject_type_decl(decl).unwrap_err();
5203        assert!(matches!(
5204            err,
5205            TypeError::ReservedTypeName(name) if name.as_ref() == "i32"
5206        ));
5207    }
5208
5209    #[test]
5210    fn reject_user_redefinition_of_prelude_adt_name() {
5211        let program = parse_program("type Result e a = Nope e a");
5212        let mut ts = TypeSystem::with_prelude().unwrap();
5213        let rexlang_ast::expr::Decl::Type(decl) = &program.decls[0] else {
5214            panic!("expected type decl");
5215        };
5216        let err = ts.inject_type_decl(decl).unwrap_err();
5217        assert!(matches!(
5218            err,
5219            TypeError::ReservedTypeName(name) if name.as_ref() == "Result"
5220        ));
5221    }
5222
5223    #[test]
5224    fn infer_polymorphic_id_tuple() {
5225        let expr = parse_expr(
5226            r#"
5227            let
5228                id = \x -> x
5229            in
5230                id (id 420, id 6.9, id "str")
5231            "#,
5232        );
5233        let mut ts = TypeSystem::with_prelude().unwrap();
5234        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5235        let expected = Type::tuple(vec![
5236            Type::builtin(BuiltinTypeId::I32),
5237            Type::builtin(BuiltinTypeId::F32),
5238            Type::builtin(BuiltinTypeId::String),
5239        ]);
5240        assert_eq!(ty, expected);
5241    }
5242
5243    #[test]
5244    fn infer_type_annotation_ok() {
5245        let expr = parse_expr("let x: i32 = 42 in x");
5246        let mut ts = TypeSystem::with_prelude().unwrap();
5247        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5248        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5249    }
5250
5251    #[test]
5252    fn infer_type_annotation_lambda_param() {
5253        let expr = parse_expr("\\ (a : f32) -> a");
5254        let mut ts = TypeSystem::with_prelude().unwrap();
5255        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5256        assert_eq!(
5257            ty,
5258            Type::fun(
5259                Type::builtin(BuiltinTypeId::F32),
5260                Type::builtin(BuiltinTypeId::F32)
5261            )
5262        );
5263    }
5264
5265    #[test]
5266    fn infer_type_annotation_is_alias() {
5267        let expr = parse_expr("\"hi\" is str");
5268        let mut ts = TypeSystem::with_prelude().unwrap();
5269        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5270        assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
5271    }
5272
5273    #[test]
5274    fn infer_type_annotation_mismatch_error() {
5275        let expr = parse_expr("let x: i32 = 3.14 in x");
5276        let mut ts = TypeSystem::with_prelude().unwrap();
5277        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5278        assert!(matches!(err, TypeError::Unification(_, _)));
5279    }
5280
5281    #[test]
5282    fn infer_project_single_variant_let() {
5283        let program = parse_program(
5284            r#"
5285            type MyADT = MyVariant1 { field1: i32, field2: f32 }
5286            let
5287                x = MyVariant1 { field1 = 1, field2 = 2.0 }
5288            in
5289                (x.field1, x.field2)
5290            "#,
5291        );
5292        let mut ts = TypeSystem::with_prelude().unwrap();
5293        for decl in &program.decls {
5294            if let rexlang_ast::expr::Decl::Type(decl) = decl {
5295                ts.inject_type_decl(decl).unwrap();
5296            }
5297        }
5298        let (_preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5299        let expected = Type::tuple(vec![
5300            Type::builtin(BuiltinTypeId::I32),
5301            Type::builtin(BuiltinTypeId::F32),
5302        ]);
5303        assert_eq!(ty, expected);
5304    }
5305
5306    #[test]
5307    fn infer_project_known_variant_let() {
5308        let program = parse_program(
5309            r#"
5310            type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
5311            let
5312                x = MyVariant1 { field1 = 1, field2 = 2.0 }
5313            in
5314                x.field1
5315            "#,
5316        );
5317        let mut ts = TypeSystem::with_prelude().unwrap();
5318        for decl in &program.decls {
5319            if let rexlang_ast::expr::Decl::Type(decl) = decl {
5320                ts.inject_type_decl(decl).unwrap();
5321            }
5322        }
5323        let (_preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5324        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5325    }
5326
5327    #[test]
5328    fn infer_project_unknown_variant_error() {
5329        let program = parse_program(
5330            r#"
5331            type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
5332            let
5333                x = MyVariant2 1 2.0
5334            in
5335                x.field1
5336            "#,
5337        );
5338        let mut ts = TypeSystem::with_prelude().unwrap();
5339        for decl in &program.decls {
5340            if let rexlang_ast::expr::Decl::Type(decl) = decl {
5341                ts.inject_type_decl(decl).unwrap();
5342            }
5343        }
5344        let err = strip_span(ts.infer(program.expr.as_ref()).unwrap_err());
5345        assert!(matches!(err, TypeError::FieldNotKnown { .. }));
5346    }
5347
5348    #[test]
5349    fn infer_project_lambda_param_single_variant() {
5350        let program = parse_program(
5351            r#"
5352            type Boxed = Boxed { value: i32 }
5353            let
5354                f = \x -> x.value
5355            in
5356                f (Boxed { value = 1 })
5357            "#,
5358        );
5359        let mut ts = TypeSystem::with_prelude().unwrap();
5360        for decl in &program.decls {
5361            if let rexlang_ast::expr::Decl::Type(decl) = decl {
5362                ts.inject_type_decl(decl).unwrap();
5363            }
5364        }
5365        let (_preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5366        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5367    }
5368
5369    #[test]
5370    fn infer_project_in_match_arm() {
5371        let program = parse_program(
5372            r#"
5373            type MyADT = MyVariant1 { field1: i32 } | MyVariant2 i32
5374            let
5375                x = MyVariant1 { field1 = 1 }
5376            in
5377                match x
5378                    when MyVariant1 { field1 } -> x.field1
5379                    when MyVariant2 _ -> 0
5380            "#,
5381        );
5382        let mut ts = TypeSystem::with_prelude().unwrap();
5383        for decl in &program.decls {
5384            if let rexlang_ast::expr::Decl::Type(decl) = decl {
5385                ts.inject_type_decl(decl).unwrap();
5386            }
5387        }
5388        let (_preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5389        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5390    }
5391
5392    #[test]
5393    fn infer_nested_let_lambda_match_option() {
5394        let expr = parse_expr(
5395            r#"
5396            let
5397                choose = \flag a b -> if flag then a else b,
5398                build = \flag ->
5399                    let
5400                        pick = choose flag,
5401                        val = pick 1 2
5402                    in
5403                        Some val
5404            in
5405                match (build true)
5406                    when Some x -> x
5407                    when None -> 0
5408            "#,
5409        );
5410        let mut ts = TypeSystem::with_prelude().unwrap();
5411        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5412        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5413    }
5414
5415    #[test]
5416    fn infer_polymorphic_apply_in_tuple() {
5417        let expr = parse_expr(
5418            r#"
5419            let
5420                apply = \f x -> f x,
5421                id = \x -> x,
5422                wrap = \x -> (x, x)
5423            in
5424                (apply id 1, apply id "hi", apply wrap true)
5425            "#,
5426        );
5427        let mut ts = TypeSystem::with_prelude().unwrap();
5428        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5429        let expected = Type::tuple(vec![
5430            Type::builtin(BuiltinTypeId::I32),
5431            Type::builtin(BuiltinTypeId::String),
5432            Type::tuple(vec![
5433                Type::builtin(BuiltinTypeId::Bool),
5434                Type::builtin(BuiltinTypeId::Bool),
5435            ]),
5436        ]);
5437        assert_eq!(ty, expected);
5438    }
5439
5440    #[test]
5441    fn infer_nested_result_option_match() {
5442        let expr = parse_expr(
5443            r#"
5444            let
5445                unwrap = \x ->
5446                    match x
5447                        when Ok (Some v) -> v
5448                        when Ok None -> 0
5449                        when Err _ -> 0
5450            in
5451                unwrap (Ok (Some 5))
5452            "#,
5453        );
5454        let mut ts = TypeSystem::with_prelude().unwrap();
5455        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5456        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5457    }
5458
5459    #[test]
5460    fn infer_head_or_list_match() {
5461        let expr = parse_expr(
5462            r#"
5463            let
5464                head_or = \fallback xs ->
5465                    match xs
5466                        when [] -> fallback
5467                        when x::xs -> x
5468            in
5469                (head_or 0 [1, 2, 3], head_or 0 [])
5470            "#,
5471        );
5472        let mut ts = TypeSystem::with_prelude().unwrap();
5473        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5474        let expected = Type::tuple(vec![
5475            Type::builtin(BuiltinTypeId::I32),
5476            Type::builtin(BuiltinTypeId::I32),
5477        ]);
5478        assert_eq!(ty, expected);
5479    }
5480
5481    #[test]
5482    fn infer_head_or_list_match_cons_constructor_form() {
5483        let expr = parse_expr(
5484            r#"
5485            let
5486                head_or = \fallback xs ->
5487                    match xs
5488                        when [] -> fallback
5489                        when Cons x xs1 -> x
5490            in
5491                (head_or 0 (Cons 1 (Cons 2 Empty)), head_or 0 Empty)
5492            "#,
5493        );
5494        let mut ts = TypeSystem::with_prelude().unwrap();
5495        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5496        let expected = Type::tuple(vec![
5497            Type::builtin(BuiltinTypeId::I32),
5498            Type::builtin(BuiltinTypeId::I32),
5499        ]);
5500        assert_eq!(ty, expected);
5501    }
5502
5503    #[test]
5504    fn infer_record_pattern_in_lambda() {
5505        let program = parse_program(
5506            r#"
5507            type Pair = Pair { left: i32, right: i32 }
5508            let
5509                sum = \p ->
5510                    match p
5511                        when Pair { left, right } -> left + right
5512            in
5513                sum (Pair { left = 1, right = 2 })
5514            "#,
5515        );
5516        let mut ts = TypeSystem::with_prelude().unwrap();
5517        for decl in &program.decls {
5518            if let rexlang_ast::expr::Decl::Type(decl) = decl {
5519                ts.inject_type_decl(decl).unwrap();
5520            }
5521        }
5522        let (_preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5523        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5524    }
5525
5526    #[test]
5527    fn infer_fn_decl_simple() {
5528        let program = parse_program(
5529            r#"
5530            fn add (x: i32, y: i32) -> i32 = x + y
5531            add 1 2
5532            "#,
5533        );
5534        let mut ts = TypeSystem::with_prelude().unwrap();
5535        let expr = program.expr_with_fns();
5536        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5537        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5538    }
5539
5540    #[test]
5541    fn infer_fn_decl_signature_form() {
5542        let program = parse_program(
5543            r#"
5544            fn add : i32 -> i32 -> i32 = \x y -> x + y
5545            add 1 2
5546            "#,
5547        );
5548        let mut ts = TypeSystem::with_prelude().unwrap();
5549        let expr = program.expr_with_fns();
5550        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5551        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5552    }
5553
5554    #[test]
5555    fn infer_fn_decl_polymorphic_where_constraints() {
5556        let program = parse_program(
5557            r#"
5558            fn my_add (x: a, y: a) -> a where AdditiveMonoid a = x + y
5559            (my_add 1 2, my_add 1.0 2.0)
5560            "#,
5561        );
5562        let mut ts = TypeSystem::with_prelude().unwrap();
5563        let expr = program.expr_with_fns();
5564        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5565        assert_eq!(
5566            ty,
5567            Type::tuple(vec![
5568                Type::builtin(BuiltinTypeId::I32),
5569                Type::builtin(BuiltinTypeId::F32)
5570            ])
5571        );
5572    }
5573
5574    #[test]
5575    fn infer_additive_monoid_constraint() {
5576        let expr = parse_expr("\\x y -> x + y");
5577        let mut ts = TypeSystem::with_prelude().unwrap();
5578        let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5579        assert_eq!(preds.len(), 1);
5580        assert_eq!(preds[0].class.as_ref(), "AdditiveMonoid");
5581
5582        if let TypeKind::Fun(a, rest) = ty.as_ref()
5583            && let TypeKind::Fun(b, c) = rest.as_ref()
5584        {
5585            assert_eq!(a.as_ref(), b.as_ref());
5586            assert_eq!(b.as_ref(), c.as_ref());
5587            assert_eq!(preds[0].typ, a.clone());
5588            return;
5589        }
5590        panic!("expected a -> a -> a");
5591    }
5592
5593    #[test]
5594    fn infer_multiplicative_monoid_constraint() {
5595        let expr = parse_expr("\\x y -> x * y");
5596        let mut ts = TypeSystem::with_prelude().unwrap();
5597        let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5598        assert_eq!(preds.len(), 1);
5599        assert_eq!(preds[0].class.as_ref(), "MultiplicativeMonoid");
5600
5601        if let TypeKind::Fun(a, rest) = ty.as_ref()
5602            && let TypeKind::Fun(b, c) = rest.as_ref()
5603        {
5604            assert_eq!(a.as_ref(), b.as_ref());
5605            assert_eq!(b.as_ref(), c.as_ref());
5606            assert_eq!(preds[0].typ, a.clone());
5607            return;
5608        }
5609        panic!("expected a -> a -> a");
5610    }
5611
5612    #[test]
5613    fn infer_additive_group_constraint() {
5614        let expr = parse_expr("\\x y -> x - y");
5615        let mut ts = TypeSystem::with_prelude().unwrap();
5616        let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5617        assert_eq!(preds.len(), 1);
5618        assert_eq!(preds[0].class.as_ref(), "AdditiveGroup");
5619
5620        if let TypeKind::Fun(a, rest) = ty.as_ref()
5621            && let TypeKind::Fun(b, c) = rest.as_ref()
5622        {
5623            assert_eq!(a.as_ref(), b.as_ref());
5624            assert_eq!(b.as_ref(), c.as_ref());
5625            assert_eq!(preds[0].typ, a.clone());
5626            return;
5627        }
5628        panic!("expected a -> a -> a");
5629    }
5630
5631    #[test]
5632    fn infer_integral_constraint() {
5633        let expr = parse_expr("\\x y -> x % y");
5634        let mut ts = TypeSystem::with_prelude().unwrap();
5635        let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5636        assert_eq!(preds.len(), 1);
5637        assert_eq!(preds[0].class.as_ref(), "Integral");
5638
5639        if let TypeKind::Fun(a, rest) = ty.as_ref()
5640            && let TypeKind::Fun(b, c) = rest.as_ref()
5641        {
5642            assert_eq!(a.as_ref(), b.as_ref());
5643            assert_eq!(b.as_ref(), c.as_ref());
5644            assert_eq!(preds[0].typ, a.clone());
5645            return;
5646        }
5647        panic!("expected a -> a -> a");
5648    }
5649
5650    #[test]
5651    fn infer_literal_addition_defaults() {
5652        let expr = parse_expr("1 + 2");
5653        let mut ts = TypeSystem::with_prelude().unwrap();
5654        let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5655        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5656        assert_eq!(preds.len(), 2);
5657        assert!(preds.iter().any(|p| p.class.as_ref() == "AdditiveMonoid"));
5658        assert!(preds.iter().any(|p| p.class.as_ref() == "Integral"));
5659        assert!(
5660            preds
5661                .iter()
5662                .all(|p| p.typ == Type::builtin(BuiltinTypeId::I32))
5663        );
5664    }
5665
5666    #[test]
5667    fn infer_mod_defaults() {
5668        let expr = parse_expr("1 % 2");
5669        let mut ts = TypeSystem::with_prelude().unwrap();
5670        let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5671        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5672        assert_eq!(preds.len(), 1);
5673        assert_eq!(preds[0].class.as_ref(), "Integral");
5674        assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::I32));
5675    }
5676
5677    #[test]
5678    fn infer_get_list_type() {
5679        let expr = parse_expr("get 1 [1, 2, 3]");
5680        let mut ts = TypeSystem::with_prelude().unwrap();
5681        let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5682        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5683        assert!(preds.iter().any(|p| p.class.as_ref() == "Indexable"));
5684        assert!(preds.iter().all(|p| {
5685            p.class.as_ref() == "Indexable"
5686                || (p.class.as_ref() == "Integral" && p.typ == Type::builtin(BuiltinTypeId::I32))
5687        }));
5688        for pred in preds.iter().filter(|p| p.class.as_ref() == "Indexable") {
5689            assert!(entails(&ts.classes, &[], pred).unwrap());
5690        }
5691    }
5692
5693    #[test]
5694    fn infer_get_tuple_type() {
5695        let expr = parse_expr("(1, 'Hello', true).0");
5696        let mut ts = TypeSystem::with_prelude().unwrap();
5697        let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5698        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5699        assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
5700
5701        let expr = parse_expr("(1, 'Hello', true).1");
5702        let mut ts = TypeSystem::with_prelude().unwrap();
5703        let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5704        assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
5705        assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
5706
5707        let expr = parse_expr("(1, 'Hello', true).2");
5708        let mut ts = TypeSystem::with_prelude().unwrap();
5709        let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5710        assert_eq!(ty, Type::builtin(BuiltinTypeId::Bool));
5711        assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
5712    }
5713
5714    #[test]
5715    fn infer_division_defaults() {
5716        let expr = parse_expr("1.0 / 2.0");
5717        let mut ts = TypeSystem::with_prelude().unwrap();
5718        let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5719        assert_eq!(ty, Type::builtin(BuiltinTypeId::F32));
5720        assert_eq!(preds.len(), 1);
5721        assert_eq!(preds[0].class.as_ref(), "Field");
5722        assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::F32));
5723        assert!(entails(&ts.classes, &[], &preds[0]).unwrap());
5724    }
5725
5726    #[test]
5727    fn infer_unbound_variable_error() {
5728        let expr = parse_expr("missing");
5729        let mut ts = TypeSystem::with_prelude().unwrap();
5730        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5731        assert!(matches!(
5732            err,
5733            TypeError::UnknownVar(name) if name.as_ref() == "missing"
5734        ));
5735    }
5736
5737    #[test]
5738    fn infer_if_branch_type_mismatch_error() {
5739        let expr = parse_expr(r#"if true then 1 else "no""#);
5740        let mut ts = TypeSystem::with_prelude().unwrap();
5741        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5742        match err {
5743            TypeError::Unification(a, b) => {
5744                let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
5745                assert!(ok, "expected i32 vs string, got {a} vs {b}");
5746            }
5747            other => panic!("expected unification error, got {other:?}"),
5748        }
5749    }
5750
5751    #[test]
5752    fn infer_unknown_pattern_constructor_error() {
5753        let expr = parse_expr("match 1 when Nope -> 1");
5754        let mut ts = TypeSystem::with_prelude().unwrap();
5755        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5756        assert!(matches!(
5757            err,
5758            TypeError::UnknownVar(name) if name.as_ref() == "Nope"
5759        ));
5760    }
5761
5762    #[test]
5763    fn infer_ambiguous_overload_error() {
5764        let mut ts = TypeSystem::new();
5765        let a = TypeVar::new(0, Some(sym("a")));
5766        let b = TypeVar::new(1, Some(sym("b")));
5767        let scheme_a = Scheme::new(vec![a.clone()], vec![], Type::var(a));
5768        let scheme_b = Scheme::new(vec![b.clone()], vec![], Type::var(b));
5769        ts.add_overload(sym("dup"), scheme_a);
5770        ts.add_overload(sym("dup"), scheme_b);
5771        let expr = parse_expr("dup");
5772        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5773        assert!(matches!(
5774            err,
5775            TypeError::AmbiguousOverload(name) if name.as_ref() == "dup"
5776        ));
5777    }
5778
5779    #[test]
5780    fn infer_if_cond_not_bool_error() {
5781        let expr = parse_expr("if 1 then 2 else 3");
5782        let mut ts = TypeSystem::with_prelude().unwrap();
5783        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5784        match err {
5785            TypeError::Unification(a, b) => {
5786                let ok = (a == "bool" && b == "i32") || (a == "i32" && b == "bool");
5787                assert!(ok, "expected bool vs i32, got {a} vs {b}");
5788            }
5789            other => panic!("expected unification error, got {other:?}"),
5790        }
5791    }
5792
5793    #[test]
5794    fn infer_apply_non_function_error() {
5795        let expr = parse_expr("1 2");
5796        let mut ts = TypeSystem::with_prelude().unwrap();
5797        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5798        assert!(matches!(err, TypeError::Unification(_, _)));
5799    }
5800
5801    #[test]
5802    fn infer_list_element_mismatch_error() {
5803        let expr = parse_expr("[1, true]");
5804        let mut ts = TypeSystem::with_prelude().unwrap();
5805        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5806        match err {
5807            TypeError::Unification(a, b) => {
5808                let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
5809                assert!(ok, "expected i32 vs bool, got {a} vs {b}");
5810            }
5811            other => panic!("expected unification error, got {other:?}"),
5812        }
5813    }
5814
5815    #[test]
5816    fn infer_dict_value_mismatch_error() {
5817        let expr = parse_expr("{a = 1, b = true}");
5818        let mut ts = TypeSystem::with_prelude().unwrap();
5819        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5820        match err {
5821            TypeError::Unification(a, b) => {
5822                let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
5823                assert!(ok, "expected i32 vs bool, got {a} vs {b}");
5824            }
5825            other => panic!("expected unification error, got {other:?}"),
5826        }
5827    }
5828
5829    #[test]
5830    fn infer_match_list_on_non_list_error() {
5831        let expr = parse_expr("match 1 when [x] -> x");
5832        let mut ts = TypeSystem::with_prelude().unwrap();
5833        assert!(ts.infer(expr.as_ref()).is_err());
5834    }
5835
5836    #[test]
5837    fn infer_pattern_constructor_arity_error() {
5838        let expr = parse_expr("match (Ok 1) when Ok x y -> x");
5839        let mut ts = TypeSystem::with_prelude().unwrap();
5840        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5841        assert!(matches!(
5842            err,
5843            TypeError::UnsupportedExpr("pattern constructor")
5844        ));
5845    }
5846
5847    #[test]
5848    fn infer_match_arm_type_mismatch_error() {
5849        let expr = parse_expr(r#"match 1 when _ -> 1 when _ -> "no""#);
5850        let mut ts = TypeSystem::with_prelude().unwrap();
5851        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5852        match err {
5853            TypeError::Unification(a, b) => {
5854                let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
5855                assert!(ok, "expected i32 vs string, got {a} vs {b}");
5856            }
5857            other => panic!("expected unification error, got {other:?}"),
5858        }
5859    }
5860
5861    #[test]
5862    fn infer_match_option_on_non_option_error() {
5863        let expr = parse_expr("match 1 when Some x -> x");
5864        let mut ts = TypeSystem::with_prelude().unwrap();
5865        assert!(ts.infer(expr.as_ref()).is_err());
5866    }
5867
5868    #[test]
5869    fn infer_dict_pattern_on_non_dict_error() {
5870        let expr = parse_expr("match 1 when {a} -> a");
5871        let mut ts = TypeSystem::with_prelude().unwrap();
5872        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5873        assert!(matches!(err, TypeError::Unification(_, _)));
5874    }
5875
5876    #[test]
5877    fn infer_cons_pattern_on_non_list_error() {
5878        let expr = parse_expr("match 1 when x::xs -> x");
5879        let mut ts = TypeSystem::with_prelude().unwrap();
5880        assert!(ts.infer(expr.as_ref()).is_err());
5881    }
5882
5883    #[test]
5884    fn infer_apply_wrong_arg_type_error() {
5885        let expr = parse_expr("(\\x -> x + 1) true");
5886        let mut ts = TypeSystem::with_prelude().unwrap();
5887        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5888        assert!(matches!(err, TypeError::Unification(_, _)));
5889    }
5890
5891    #[test]
5892    fn infer_self_application_occurs_error() {
5893        let expr = parse_expr("\\x -> x x");
5894        let mut ts = TypeSystem::with_prelude().unwrap();
5895        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5896        assert!(matches!(err, TypeError::Occurs(_, _)));
5897    }
5898
5899    #[test]
5900    fn infer_apply_constructor_too_many_args_error() {
5901        let expr = parse_expr("Some 1 2");
5902        let mut ts = TypeSystem::with_prelude().unwrap();
5903        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5904        assert!(matches!(err, TypeError::Unification(_, _)));
5905    }
5906
5907    #[test]
5908    fn infer_operator_type_mismatch_error() {
5909        let expr = parse_expr("1 + true");
5910        let mut ts = TypeSystem::with_prelude().unwrap();
5911        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5912        assert!(matches!(err, TypeError::Unification(_, _)));
5913    }
5914
5915    #[test]
5916    fn infer_non_exhaustive_match_is_error() {
5917        let expr = parse_expr("match (Ok 1) when Ok x -> x");
5918        let mut ts = TypeSystem::with_prelude().unwrap();
5919        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5920        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
5921    }
5922
5923    #[test]
5924    fn infer_non_exhaustive_match_on_bound_var_error() {
5925        let expr = parse_expr("let x = Ok 1 in match x when Ok y -> y");
5926        let mut ts = TypeSystem::with_prelude().unwrap();
5927        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5928        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
5929    }
5930
5931    #[test]
5932    fn infer_non_exhaustive_match_in_lambda_error() {
5933        let expr = parse_expr("\\x -> match x when Ok y -> y");
5934        let mut ts = TypeSystem::with_prelude().unwrap();
5935        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5936        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
5937    }
5938
5939    #[test]
5940    fn infer_non_exhaustive_option_match_error() {
5941        let expr = parse_expr("match (Some 1) when Some x -> x");
5942        let mut ts = TypeSystem::with_prelude().unwrap();
5943        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5944        match err {
5945            TypeError::NonExhaustiveMatch { missing, .. } => {
5946                assert_eq!(missing, vec![sym("None")]);
5947            }
5948            other => panic!("expected non-exhaustive match, got {other:?}"),
5949        }
5950    }
5951
5952    #[test]
5953    fn infer_non_exhaustive_result_match_error() {
5954        let expr = parse_expr("match (Err 1) when Ok x -> x");
5955        let mut ts = TypeSystem::with_prelude().unwrap();
5956        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5957        match err {
5958            TypeError::NonExhaustiveMatch { missing, .. } => {
5959                assert_eq!(missing, vec![sym("Err")]);
5960            }
5961            other => panic!("expected non-exhaustive match, got {other:?}"),
5962        }
5963    }
5964
5965    #[test]
5966    fn infer_non_exhaustive_list_missing_empty_error() {
5967        let expr = parse_expr("match [1, 2] when x::xs -> x");
5968        let mut ts = TypeSystem::with_prelude().unwrap();
5969        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5970        match err {
5971            TypeError::NonExhaustiveMatch { missing, .. } => {
5972                assert_eq!(missing, vec![sym("Empty")]);
5973            }
5974            other => panic!("expected non-exhaustive match, got {other:?}"),
5975        }
5976    }
5977
5978    #[test]
5979    fn infer_non_exhaustive_list_match_on_bound_var_error() {
5980        let expr = parse_expr("let xs = [1, 2] in match xs when x::xs -> x");
5981        let mut ts = TypeSystem::with_prelude().unwrap();
5982        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5983        assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
5984    }
5985
5986    #[test]
5987    fn infer_non_exhaustive_list_missing_cons_error() {
5988        let expr = parse_expr("match [1] when [] -> 0");
5989        let mut ts = TypeSystem::with_prelude().unwrap();
5990        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5991        match err {
5992            TypeError::NonExhaustiveMatch { missing, .. } => {
5993                assert_eq!(missing, vec![sym("Cons")]);
5994            }
5995            other => panic!("expected non-exhaustive match, got {other:?}"),
5996        }
5997    }
5998
5999    #[test]
6000    fn infer_match_list_patterns_on_result_error() {
6001        let expr = parse_expr("match (Ok 1) when [] -> 0 when x::xs -> 1");
6002        let mut ts = TypeSystem::with_prelude().unwrap();
6003        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
6004        assert!(matches!(err, TypeError::Unification(_, _)));
6005    }
6006
6007    #[test]
6008    fn infer_missing_instances_produce_unsatisfied_predicates() {
6009        for (name, code) in [
6010            ("division", "1 / 2"),
6011            ("eq_dict", "{a = 1} == {a = 2}"),
6012            ("min_bool", "min [true]"),
6013            ("map_dict", r#"map (\x -> x) {a = 1}"#),
6014        ] {
6015            let (class, pred_type, expected_ty) = match name {
6016                "division" => (
6017                    "Field",
6018                    Type::builtin(BuiltinTypeId::I32),
6019                    Some(Type::builtin(BuiltinTypeId::I32)),
6020                ),
6021                "eq_dict" => ("Eq", dict_of(Type::builtin(BuiltinTypeId::I32)), None),
6022                "min_bool" => ("Ord", Type::builtin(BuiltinTypeId::Bool), None),
6023                "map_dict" => ("Functor", Type::builtin(BuiltinTypeId::Dict), None),
6024                _ => unreachable!("unknown test case {name}"),
6025            };
6026
6027            let expr = parse_expr(code);
6028            let mut ts = TypeSystem::with_prelude().unwrap();
6029            let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
6030            if let Some(expected) = expected_ty {
6031                assert_eq!(ty, expected, "{name}");
6032            }
6033
6034            let pred = preds
6035                .iter()
6036                .find(|p| p.class.as_ref() == class && p.typ == pred_type)
6037                .unwrap();
6038            assert!(!entails(&ts.classes, &[], pred).unwrap(), "{name}");
6039        }
6040    }
6041
6042    #[test]
6043    fn record_update_single_variant_adt_infers() {
6044        let program = parse_program(
6045            r#"
6046            type Foo = Bar { x: i32, y: i32 }
6047            let
6048              foo: Foo = Bar { x = 1, y = 2 },
6049              bar = { foo with { x = 3 } }
6050            in
6051              bar
6052            "#,
6053        );
6054        let mut ts = TypeSystem::with_prelude().unwrap();
6055        ts.inject_decls(&program.decls).unwrap();
6056        let (_preds, typ) = ts.infer(program.expr.as_ref()).unwrap();
6057        assert_eq!(typ.to_string(), "Foo");
6058    }
6059
6060    #[test]
6061    fn record_update_unknown_field_errors() {
6062        let program = parse_program(
6063            r#"
6064            type Foo = Bar { x: i32 }
6065            let
6066              foo: Foo = Bar { x = 1 }
6067            in
6068              { foo with { y = 2 } }
6069            "#,
6070        );
6071        let mut ts = TypeSystem::with_prelude().unwrap();
6072        ts.inject_decls(&program.decls).unwrap();
6073        let err = ts.infer(program.expr.as_ref()).unwrap_err();
6074        let err = strip_span(err);
6075        assert!(matches!(err, TypeError::UnknownField { .. }));
6076    }
6077
6078    #[test]
6079    fn record_update_requires_refined_variant_for_sum_types() {
6080        let program = parse_program(
6081            r#"
6082            type Foo = Bar { x: i32 } | Baz { x: i32 }
6083            let
6084              f = \ (foo : Foo) -> { foo with { x = 2 } }
6085            in
6086              f (Bar { x = 1 })
6087            "#,
6088        );
6089        let mut ts = TypeSystem::with_prelude().unwrap();
6090        ts.inject_decls(&program.decls).unwrap();
6091        let err = ts.infer(program.expr.as_ref()).unwrap_err();
6092        let err = strip_span(err);
6093        assert!(matches!(err, TypeError::FieldNotKnown { .. }));
6094    }
6095
6096    #[test]
6097    fn record_update_allowed_after_match_refines_variant() {
6098        let program = parse_program(
6099            r#"
6100            type Foo = Bar { x: i32 } | Baz { x: i32 }
6101            let
6102              f = \ (foo : Foo) ->
6103                match foo
6104                  when Bar {x} -> { foo with { x = x + 1 } }
6105                  when Baz {x} -> { foo with { x = x + 2 } }
6106            in
6107              f (Bar { x = 1 })
6108            "#,
6109        );
6110        let mut ts = TypeSystem::with_prelude().unwrap();
6111        ts.inject_decls(&program.decls).unwrap();
6112        let (_preds, typ) = ts.infer(program.expr.as_ref()).unwrap();
6113        assert_eq!(typ.to_string(), "Foo");
6114    }
6115
6116    #[test]
6117    fn record_update_plain_record_type() {
6118        let program = parse_program(
6119            r#"
6120            let
6121              f = \ (r : { x: i32, y: i32 }) -> { r with { y = 9 } }
6122            in
6123              f { x = 1, y = 2 }
6124            "#,
6125        );
6126        let mut ts = TypeSystem::with_prelude().unwrap();
6127        ts.inject_decls(&program.decls).unwrap();
6128        let (_preds, typ) = ts.infer(program.expr.as_ref()).unwrap();
6129        assert_eq!(typ.to_string(), "{x: i32, y: i32}");
6130    }
6131
6132    #[test]
6133    fn infer_typed_hole_expr_is_hole_kind() {
6134        let expr = parse_expr("?");
6135        let mut ts = TypeSystem::with_prelude().unwrap();
6136        let (typed, _preds, _ty) = ts.infer_typed(expr.as_ref()).unwrap();
6137        assert!(
6138            matches!(typed.kind, TypedExprKind::Hole),
6139            "typed={typed:#?}"
6140        );
6141    }
6142
6143    #[test]
6144    fn infer_hole_with_annotation_unifies_to_annotation() {
6145        let expr = parse_expr("let x : i32 = ? in x");
6146        let mut ts = TypeSystem::with_prelude().unwrap();
6147        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
6148        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
6149    }
6150
6151    #[test]
6152    fn infer_hole_in_if_condition_is_bool_constrained() {
6153        let expr = parse_expr("if ? then 1 else 2");
6154        let mut ts = TypeSystem::with_prelude().unwrap();
6155        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
6156        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
6157    }
6158
6159    #[test]
6160    fn infer_hole_in_arithmetic_is_numeric_constrained() {
6161        let expr = parse_expr("? + 1");
6162        let mut ts = TypeSystem::with_prelude().unwrap();
6163        let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
6164        assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
6165    }
6166
6167    #[test]
6168    fn infer_hole_arithmetic_conflicting_annotation_failure() {
6169        let expr = parse_expr("let x : string = (? + 1) in x");
6170        let mut ts = TypeSystem::with_prelude().unwrap();
6171        let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
6172        assert!(matches!(err, TypeError::Unification(_, _)), "err={err:#?}");
6173    }
6174}