Skip to main content

rexlang_typesystem/
typesystem.rs

1//! Core type system implementation for Rex.
2
3use std::collections::{BTreeMap, HashMap, HashSet};
4use std::fmt::{self, Display, Formatter};
5use std::sync::Arc;
6
7use chrono::{DateTime, Utc};
8use rexlang_ast::expr::{
9    ClassDecl, ClassMethodSig, Decl, DeclareFnDecl, Expr, FnDecl, InstanceDecl, InstanceMethodImpl,
10    Pattern, Scope, Symbol, TypeConstraint, TypeDecl, TypeExpr, intern, sym,
11};
12use rexlang_lexer::span::Span;
13use rexlang_util::{GasMeter, OutOfGas};
14use rpds::HashTrieMapSync;
15use uuid::Uuid;
16
17use crate::prelude;
18
19#[path = "inference.rs"]
20pub mod inference;
21
22pub use inference::{infer, infer_typed, infer_typed_with_gas, infer_with_gas};
23
/// Identifier of a type variable (the `id` of a [`TypeVar`]).
///
/// Compared and hashed as a plain integer; the unifier indexes its dense
/// binding table by this value.
pub type TypeVarId = usize;
25
26#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
27pub enum BuiltinTypeId {
28    U8,
29    U16,
30    U32,
31    U64,
32    I8,
33    I16,
34    I32,
35    I64,
36    F32,
37    F64,
38    Bool,
39    String,
40    Uuid,
41    DateTime,
42    List,
43    Array,
44    Dict,
45    Option,
46    Result,
47}
48
49impl BuiltinTypeId {
50    pub fn as_symbol(self) -> Symbol {
51        sym(self.as_str())
52    }
53
54    pub fn as_str(self) -> &'static str {
55        match self {
56            Self::U8 => "u8",
57            Self::U16 => "u16",
58            Self::U32 => "u32",
59            Self::U64 => "u64",
60            Self::I8 => "i8",
61            Self::I16 => "i16",
62            Self::I32 => "i32",
63            Self::I64 => "i64",
64            Self::F32 => "f32",
65            Self::F64 => "f64",
66            Self::Bool => "bool",
67            Self::String => "string",
68            Self::Uuid => "uuid",
69            Self::DateTime => "datetime",
70            Self::List => "List",
71            Self::Array => "Array",
72            Self::Dict => "Dict",
73            Self::Option => "Option",
74            Self::Result => "Result",
75        }
76    }
77
78    pub fn arity(self) -> usize {
79        match self {
80            Self::List | Self::Array | Self::Dict | Self::Option => 1,
81            Self::Result => 2,
82            _ => 0,
83        }
84    }
85
86    pub fn from_symbol(name: &Symbol) -> Option<Self> {
87        Self::from_name(name.as_ref())
88    }
89
90    pub fn from_name(name: &str) -> Option<Self> {
91        match name {
92            "u8" => Some(Self::U8),
93            "u16" => Some(Self::U16),
94            "u32" => Some(Self::U32),
95            "u64" => Some(Self::U64),
96            "i8" => Some(Self::I8),
97            "i16" => Some(Self::I16),
98            "i32" => Some(Self::I32),
99            "i64" => Some(Self::I64),
100            "f32" => Some(Self::F32),
101            "f64" => Some(Self::F64),
102            "bool" => Some(Self::Bool),
103            "string" => Some(Self::String),
104            "uuid" => Some(Self::Uuid),
105            "datetime" => Some(Self::DateTime),
106            "List" => Some(Self::List),
107            "Array" => Some(Self::Array),
108            "Dict" => Some(Self::Dict),
109            "Option" => Some(Self::Option),
110            "Result" => Some(Self::Result),
111            _ => None,
112        }
113    }
114}
115
116#[derive(Clone, Debug, Eq, Hash, PartialEq)]
117pub struct TypeVar {
118    pub id: TypeVarId,
119    pub name: Option<Symbol>,
120}
121
122impl TypeVar {
123    pub fn new(id: TypeVarId, name: impl Into<Option<Symbol>>) -> Self {
124        Self {
125            id,
126            name: name.into(),
127        }
128    }
129}
130
/// A named type constructor such as `i32`, `List`, or a user-declared type.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TypeConst {
    /// Constructor name; for builtins this is the spelling from
    /// `BuiltinTypeId::as_str`.
    pub name: Symbol,
    /// Number of type arguments the constructor takes (`0` for ground types).
    pub arity: usize,
    /// `Some` when built through `Type::builtin` (directly, or via `Type::con`
    /// when name and arity match a builtin); `None` for `Type::user_con`.
    pub builtin_id: Option<BuiltinTypeId>,
}
137
/// A type term.
///
/// Cloning is cheap: the [`TypeKind`] node is shared behind an `Arc`.
/// Equality and hashing are structural.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Type(Arc<TypeKind>);

/// The shape of a single type-term node.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TypeKind {
    /// A type variable.
    Var(TypeVar),
    /// A type constructor, possibly still awaiting arguments.
    Con(TypeConst),
    /// Application of a type to one argument (`f arg`). Multi-argument
    /// constructors are curried; `Result` is stored as `Result err ok`.
    App(Type, Type),
    /// Function type `arg -> ret`.
    Fun(Type, Type),
    /// Tuple type `(T1, ..., Tn)`.
    Tuple(Vec<Type>),
    /// Record type `{a: T, b: U}`.
    ///
    /// Invariant: fields are sorted by name. This makes record equality and
    /// unification a cheap zip over two vectors, and it makes printing stable.
    /// `Type::record` establishes the ordering; code that builds this variant
    /// directly through `Type::new` must keep the fields sorted itself.
    Record(Vec<(Symbol, Type)>),
}
154
155impl Type {
156    pub fn new(kind: TypeKind) -> Self {
157        Type(Arc::new(kind))
158    }
159
160    pub fn con(name: impl AsRef<str>, arity: usize) -> Self {
161        if let Some(id) = BuiltinTypeId::from_name(name.as_ref())
162            && id.arity() == arity
163        {
164            return Self::builtin(id);
165        }
166        Self::user_con(name, arity)
167    }
168
169    pub fn user_con(name: impl AsRef<str>, arity: usize) -> Self {
170        Type::new(TypeKind::Con(TypeConst {
171            name: intern(name.as_ref()),
172            arity,
173            builtin_id: None,
174        }))
175    }
176
177    pub fn builtin(id: BuiltinTypeId) -> Self {
178        Type::new(TypeKind::Con(TypeConst {
179            name: id.as_symbol(),
180            arity: id.arity(),
181            builtin_id: Some(id),
182        }))
183    }
184
185    pub fn var(tv: TypeVar) -> Self {
186        Type::new(TypeKind::Var(tv))
187    }
188
189    pub fn fun(a: Type, b: Type) -> Self {
190        Type::new(TypeKind::Fun(a, b))
191    }
192
193    pub fn app(f: Type, arg: Type) -> Self {
194        Type::new(TypeKind::App(f, arg))
195    }
196
197    pub fn tuple(elems: Vec<Type>) -> Self {
198        Type::new(TypeKind::Tuple(elems))
199    }
200
201    pub fn record(mut fields: Vec<(Symbol, Type)>) -> Self {
202        // Canonicalize records so downstream code can rely on “same shape means
203        // same ordering”. (This is a correctness invariant, not a nicety.)
204        fields.sort_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
205        Type::new(TypeKind::Record(fields))
206    }
207
208    pub fn list(elem: Type) -> Type {
209        Type::app(Type::builtin(BuiltinTypeId::List), elem)
210    }
211
212    pub fn array(elem: Type) -> Type {
213        Type::app(Type::builtin(BuiltinTypeId::Array), elem)
214    }
215
216    pub fn dict(elem: Type) -> Type {
217        Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
218    }
219
220    pub fn option(elem: Type) -> Type {
221        Type::app(Type::builtin(BuiltinTypeId::Option), elem)
222    }
223
224    pub fn result(ok: Type, err: Type) -> Type {
225        Type::app(Type::app(Type::builtin(BuiltinTypeId::Result), err), ok)
226    }
227
228    fn apply_with_change(&self, s: &Subst) -> (Type, bool) {
229        match self.as_ref() {
230            TypeKind::Var(tv) => match s.get(&tv.id) {
231                Some(ty) => (ty.clone(), true),
232                None => (self.clone(), false),
233            },
234            TypeKind::Con(_) => (self.clone(), false),
235            TypeKind::App(l, r) => {
236                let (l_new, l_changed) = l.apply_with_change(s);
237                let (r_new, r_changed) = r.apply_with_change(s);
238                if l_changed || r_changed {
239                    (Type::app(l_new, r_new), true)
240                } else {
241                    (self.clone(), false)
242                }
243            }
244            TypeKind::Fun(_, _) => {
245                // Avoid recursive descent on long function chains like
246                // `a1 -> a2 -> ... -> an -> r`.
247                let mut args = Vec::new();
248                let mut changed = false;
249                let mut cur: &Type = self;
250                while let TypeKind::Fun(a, b) = cur.as_ref() {
251                    let (a_new, a_changed) = a.apply_with_change(s);
252                    changed |= a_changed;
253                    args.push(a_new);
254                    cur = b;
255                }
256                let (ret_new, ret_changed) = cur.apply_with_change(s);
257                changed |= ret_changed;
258                if !changed {
259                    return (self.clone(), false);
260                }
261                let mut out = ret_new;
262                for a_new in args.into_iter().rev() {
263                    out = Type::fun(a_new, out);
264                }
265                (out, true)
266            }
267            TypeKind::Tuple(ts) => {
268                let mut changed = false;
269                let mut out = Vec::with_capacity(ts.len());
270                for t in ts {
271                    let (t_new, t_changed) = t.apply_with_change(s);
272                    changed |= t_changed;
273                    out.push(t_new);
274                }
275                if changed {
276                    (Type::new(TypeKind::Tuple(out)), true)
277                } else {
278                    (self.clone(), false)
279                }
280            }
281            TypeKind::Record(fields) => {
282                let mut changed = false;
283                let mut out = Vec::with_capacity(fields.len());
284                for (k, v) in fields {
285                    let (v_new, v_changed) = v.apply_with_change(s);
286                    changed |= v_changed;
287                    out.push((k.clone(), v_new));
288                }
289                if changed {
290                    (Type::new(TypeKind::Record(out)), true)
291                } else {
292                    (self.clone(), false)
293                }
294            }
295        }
296    }
297}
298
299impl AsRef<TypeKind> for Type {
300    fn as_ref(&self) -> &TypeKind {
301        self.0.as_ref()
302    }
303}
304
305impl std::ops::Deref for Type {
306    type Target = TypeKind;
307
308    fn deref(&self) -> &Self::Target {
309        &self.0
310    }
311}
312
313impl Display for Type {
314    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
315        match self.as_ref() {
316            TypeKind::Var(tv) => match &tv.name {
317                Some(name) => write!(f, "'{}", name),
318                None => write!(f, "t{}", tv.id),
319            },
320            TypeKind::Con(c) => write!(f, "{}", c.name),
321            TypeKind::App(l, r) => {
322                // Internally `Result` is represented as `Result err ok` so it can be partially
323                // applied as `Result err` for HKTs (Functor/Monad/etc).
324                //
325                // User-facing syntax is `Result ok err` (Rust-style), so render the fully
326                // applied form with swapped arguments.
327                if let TypeKind::App(head, err) = l.as_ref()
328                    && matches!(
329                        head.as_ref(),
330                        TypeKind::Con(c)
331                            if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
332                    )
333                {
334                    return write!(f, "(Result {} {})", r, err);
335                }
336                write!(f, "({} {})", l, r)
337            }
338            TypeKind::Fun(a, b) => write!(f, "({} -> {})", a, b),
339            TypeKind::Tuple(elems) => {
340                write!(f, "(")?;
341                for (i, t) in elems.iter().enumerate() {
342                    write!(f, "{}", t)?;
343                    if i + 1 < elems.len() {
344                        write!(f, ", ")?;
345                    }
346                }
347                write!(f, ")")
348            }
349            TypeKind::Record(fields) => {
350                write!(f, "{{")?;
351                for (i, (name, ty)) in fields.iter().enumerate() {
352                    write!(f, "{}: {}", name, ty)?;
353                    if i + 1 < fields.len() {
354                        write!(f, ", ")?;
355                    }
356                }
357                write!(f, "}}")
358            }
359        }
360    }
361}
362
363#[derive(Clone, Debug, PartialEq, Eq, Hash)]
364pub struct Predicate {
365    pub class: Symbol,
366    pub typ: Type,
367}
368
369impl Predicate {
370    pub fn new(class: impl AsRef<str>, typ: Type) -> Self {
371        Self {
372            class: intern(class.as_ref()),
373            typ,
374        }
375    }
376}
377
378#[derive(Clone, Debug, PartialEq)]
379pub struct Scheme {
380    pub vars: Vec<TypeVar>,
381    pub preds: Vec<Predicate>,
382    pub typ: Type,
383}
384
385impl Scheme {
386    pub fn new(vars: Vec<TypeVar>, preds: Vec<Predicate>, typ: Type) -> Self {
387        Self { vars, preds, typ }
388    }
389}
390
/// A substitution mapping type-variable ids to types.
///
/// Backed by an `rpds` map: `insert` returns a new map instead of mutating
/// in place (as `compose_subst` and `Scheme::apply` rely on).
pub type Subst = HashTrieMapSync<TypeVarId, Type>;

/// Values that mention type variables and can have a [`Subst`] applied.
pub trait Types: Sized {
    /// Returns a copy of `self` with `s` applied.
    fn apply(&self, s: &Subst) -> Self;
    /// Returns the ids of the free type variables occurring in `self`.
    fn ftv(&self) -> HashSet<TypeVarId>;
}
397
398impl Types for Type {
399    fn apply(&self, s: &Subst) -> Self {
400        self.apply_with_change(s).0
401    }
402
403    fn ftv(&self) -> HashSet<TypeVarId> {
404        let mut out = HashSet::new();
405        let mut stack: Vec<&Type> = vec![self];
406        while let Some(t) = stack.pop() {
407            match t.as_ref() {
408                TypeKind::Var(tv) => {
409                    out.insert(tv.id);
410                }
411                TypeKind::Con(_) => {}
412                TypeKind::App(l, r) => {
413                    stack.push(l);
414                    stack.push(r);
415                }
416                TypeKind::Fun(a, b) => {
417                    stack.push(a);
418                    stack.push(b);
419                }
420                TypeKind::Tuple(ts) => {
421                    for t in ts {
422                        stack.push(t);
423                    }
424                }
425                TypeKind::Record(fields) => {
426                    for (_, ty) in fields {
427                        stack.push(ty);
428                    }
429                }
430            }
431        }
432        out
433    }
434}
435
436impl Types for Predicate {
437    fn apply(&self, s: &Subst) -> Self {
438        Predicate {
439            class: self.class.clone(),
440            typ: self.typ.apply(s),
441        }
442    }
443
444    fn ftv(&self) -> HashSet<TypeVarId> {
445        self.typ.ftv()
446    }
447}
448
449impl Types for Scheme {
450    fn apply(&self, s: &Subst) -> Self {
451        let mut s_pruned = Subst::new_sync();
452        for (k, v) in s.iter() {
453            if !self.vars.iter().any(|var| var.id == *k) {
454                s_pruned = s_pruned.insert(*k, v.clone());
455            }
456        }
457        Scheme::new(
458            self.vars.clone(),
459            self.preds.iter().map(|p| p.apply(&s_pruned)).collect(),
460            self.typ.apply(&s_pruned),
461        )
462    }
463
464    fn ftv(&self) -> HashSet<TypeVarId> {
465        let mut ftv = self.typ.ftv();
466        for p in &self.preds {
467            ftv.extend(p.ftv());
468        }
469        for v in &self.vars {
470            ftv.remove(&v.id);
471        }
472        ftv
473    }
474}
475
476impl<T: Types> Types for Vec<T> {
477    fn apply(&self, s: &Subst) -> Self {
478        self.iter().map(|t| t.apply(s)).collect()
479    }
480
481    fn ftv(&self) -> HashSet<TypeVarId> {
482        self.iter().flat_map(Types::ftv).collect()
483    }
484}
485
/// An expression tree annotated with a type at every node.
#[derive(Clone, Debug, PartialEq)]
pub struct TypedExpr {
    /// The type of this node.
    pub typ: Type,
    /// The expression form, holding any typed children.
    pub kind: TypedExprKind,
}

impl TypedExpr {
    /// Builds a typed node from a type and an expression form.
    pub fn new(typ: Type, kind: TypedExprKind) -> Self {
        Self { typ, kind }
    }

    /// Returns a copy of the tree with `s` applied to every node's type and
    /// to `Var` overload types. Names, literals and patterns are cloned
    /// unchanged.
    pub fn apply(&self, s: &Subst) -> Self {
        // Fast paths: chains of nested lambdas and left-nested application
        // spines (`f a b c`) are walked with a loop instead of recursing once
        // per node, so long chains do not deepen the call stack.
        match &self.kind {
            TypedExprKind::Lam { .. } => {
                // Record (param, substituted node type) for each lambda in
                // the chain, outermost first.
                let mut params: Vec<(Symbol, Type)> = Vec::new();
                let mut cur = self;
                while let TypedExprKind::Lam { param, body } = &cur.kind {
                    params.push((param.clone(), cur.typ.apply(s)));
                    cur = body.as_ref();
                }
                // `cur` is the first non-lambda body; rebuild the lambdas
                // around it from the innermost outwards.
                let mut out = cur.apply(s);
                for (param, typ) in params.into_iter().rev() {
                    out = TypedExpr {
                        typ,
                        kind: TypedExprKind::Lam {
                            param,
                            body: Box::new(out),
                        },
                    };
                }
                return out;
            }
            TypedExprKind::App(..) => {
                // Walk down the function side, recording each application's
                // substituted type and its (not yet substituted) argument,
                // outermost first.
                let mut apps: Vec<(Type, &TypedExpr)> = Vec::new();
                let mut cur = self;
                while let TypedExprKind::App(f, x) = &cur.kind {
                    apps.push((cur.typ.apply(s), x.as_ref()));
                    cur = f.as_ref();
                }
                // `cur` is the head of the spine; re-apply the arguments from
                // the innermost application outwards.
                let mut out = cur.apply(s);
                for (typ, arg) in apps.into_iter().rev() {
                    out = TypedExpr {
                        typ,
                        kind: TypedExprKind::App(Box::new(out), Box::new(arg.apply(s))),
                    };
                }
                return out;
            }
            _ => {}
        }

        let typ = self.typ.apply(s);
        let kind = match &self.kind {
            TypedExprKind::Bool(v) => TypedExprKind::Bool(*v),
            TypedExprKind::Uint(v) => TypedExprKind::Uint(*v),
            TypedExprKind::Int(v) => TypedExprKind::Int(*v),
            TypedExprKind::Float(v) => TypedExprKind::Float(*v),
            TypedExprKind::String(v) => TypedExprKind::String(v.clone()),
            TypedExprKind::Uuid(v) => TypedExprKind::Uuid(*v),
            TypedExprKind::DateTime(v) => TypedExprKind::DateTime(*v),
            TypedExprKind::Hole => TypedExprKind::Hole,
            TypedExprKind::Tuple(elems) => {
                TypedExprKind::Tuple(elems.iter().map(|e| e.apply(s)).collect())
            }
            TypedExprKind::List(elems) => {
                TypedExprKind::List(elems.iter().map(|e| e.apply(s)).collect())
            }
            TypedExprKind::Dict(kvs) => {
                let mut out = BTreeMap::new();
                for (k, v) in kvs {
                    out.insert(k.clone(), v.apply(s));
                }
                TypedExprKind::Dict(out)
            }
            TypedExprKind::RecordUpdate { base, updates } => {
                let mut out = BTreeMap::new();
                for (k, v) in updates {
                    out.insert(k.clone(), v.apply(s));
                }
                TypedExprKind::RecordUpdate {
                    base: Box::new(base.apply(s)),
                    updates: out,
                }
            }
            TypedExprKind::Var { name, overloads } => TypedExprKind::Var {
                name: name.clone(),
                overloads: overloads.iter().map(|t| t.apply(s)).collect(),
            },
            // Handled by the fast path above; kept so the match stays exhaustive.
            TypedExprKind::App(f, x) => {
                TypedExprKind::App(Box::new(f.apply(s)), Box::new(x.apply(s)))
            }
            TypedExprKind::Project { expr, field } => TypedExprKind::Project {
                expr: Box::new(expr.apply(s)),
                field: field.clone(),
            },
            // Handled by the fast path above; kept so the match stays exhaustive.
            TypedExprKind::Lam { param, body } => TypedExprKind::Lam {
                param: param.clone(),
                body: Box::new(body.apply(s)),
            },
            TypedExprKind::Let { name, def, body } => TypedExprKind::Let {
                name: name.clone(),
                def: Box::new(def.apply(s)),
                body: Box::new(body.apply(s)),
            },
            TypedExprKind::LetRec { bindings, body } => TypedExprKind::LetRec {
                bindings: bindings
                    .iter()
                    .map(|(name, def)| (name.clone(), def.apply(s)))
                    .collect(),
                body: Box::new(body.apply(s)),
            },
            TypedExprKind::Ite {
                cond,
                then_expr,
                else_expr,
            } => TypedExprKind::Ite {
                cond: Box::new(cond.apply(s)),
                then_expr: Box::new(then_expr.apply(s)),
                else_expr: Box::new(else_expr.apply(s)),
            },
            TypedExprKind::Match { scrutinee, arms } => TypedExprKind::Match {
                scrutinee: Box::new(scrutinee.apply(s)),
                arms: arms.iter().map(|(p, e)| (p.clone(), e.apply(s))).collect(),
            },
        };
        TypedExpr { typ, kind }
    }
}
614
/// The expression forms a [`TypedExpr`] node can take.
#[derive(Clone, Debug, PartialEq)]
pub enum TypedExprKind {
    /// Boolean literal.
    Bool(bool),
    /// Unsigned integer literal.
    Uint(u64),
    /// Signed integer literal.
    Int(i64),
    /// Floating-point literal.
    Float(f64),
    /// String literal.
    String(String),
    /// UUID literal.
    Uuid(Uuid),
    /// UTC timestamp literal.
    DateTime(DateTime<Utc>),
    /// A hole placeholder. NOTE(review): confirm its semantics in inference.
    Hole,
    /// Tuple expression; elements in order.
    Tuple(Vec<TypedExpr>),
    /// List expression; elements in order.
    List(Vec<TypedExpr>),
    /// Dictionary literal with entries keyed by symbol (kept in `BTreeMap` order).
    Dict(BTreeMap<Symbol, TypedExpr>),
    /// Update of `base` with new values for the fields listed in `updates`.
    RecordUpdate {
        base: Box<TypedExpr>,
        updates: BTreeMap<Symbol, TypedExpr>,
    },
    /// Variable reference. `overloads` holds candidate types for the name;
    /// NOTE(review): confirm how inference populates it.
    Var {
        name: Symbol,
        overloads: Vec<Type>,
    },
    /// Application of a function expression to one argument.
    App(Box<TypedExpr>, Box<TypedExpr>),
    /// Field projection `expr.field`.
    Project {
        expr: Box<TypedExpr>,
        field: Symbol,
    },
    /// Single-parameter lambda.
    Lam {
        param: Symbol,
        body: Box<TypedExpr>,
    },
    /// Non-recursive binding of `name` to `def` within `body`.
    Let {
        name: Symbol,
        def: Box<TypedExpr>,
        body: Box<TypedExpr>,
    },
    /// Group of recursive bindings scoped over `body`.
    LetRec {
        bindings: Vec<(Symbol, TypedExpr)>,
        body: Box<TypedExpr>,
    },
    /// If-then-else.
    Ite {
        cond: Box<TypedExpr>,
        then_expr: Box<TypedExpr>,
        else_expr: Box<TypedExpr>,
    },
    /// Pattern match of `scrutinee` against `arms`, tried in order.
    Match {
        scrutinee: Box<TypedExpr>,
        arms: Vec<(Pattern, TypedExpr)>,
    },
}
664
665/// Compose substitutions `a` after `b`.
666///
667/// If `t.apply(&b)` is “apply `b` first”, then:
668/// `t.apply(&compose_subst(a, b)) == t.apply(&b).apply(&a)`.
669pub fn compose_subst(a: Subst, b: Subst) -> Subst {
670    if subst_is_empty(&a) {
671        return b;
672    }
673    if subst_is_empty(&b) {
674        return a;
675    }
676    let mut res = Subst::new_sync();
677    for (k, v) in b.iter() {
678        res = res.insert(*k, v.apply(&a));
679    }
680    for (k, v) in a.iter() {
681        res = res.insert(*k, v.clone());
682    }
683    res
684}
685
686fn subst_is_empty(s: &Subst) -> bool {
687    s.iter().next().is_none()
688}
689
/// Errors reported by the type system.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum TypeError {
    /// Two types failed to unify; both are carried in rendered form.
    #[error("types do not unify: {0} vs {1}")]
    Unification(String, String),
    /// The occurs check failed: type variable `.0` appears in the rendered type `.1`.
    #[error("occurs check failed for {0} in {1}")]
    Occurs(TypeVarId, String),
    /// Reference to a class that is not defined.
    #[error("unknown class {0}")]
    UnknownClass(Symbol),
    /// No instance of class `.0` exists for the rendered type `.1`.
    #[error("no instance for {0} {1}")]
    NoInstance(Symbol, String),
    /// Reference to a type name that is not defined.
    #[error("unknown type {0}")]
    UnknownTypeName(Symbol),
    /// Attempt to redefine a reserved builtin type name.
    #[error("cannot redefine reserved builtin type `{0}`")]
    ReservedTypeName(Symbol),
    /// A value name was defined more than once.
    #[error("duplicate value definition `{0}`")]
    DuplicateValue(Symbol),
    /// A class name was defined more than once.
    #[error("duplicate class definition `{0}`")]
    DuplicateClass(Symbol),
    /// A class was declared with zero type parameters (`got` is the count given).
    #[error("class `{class}` must have at least one type parameter (got {got})")]
    InvalidClassArity { class: Symbol, got: usize },
    /// A class declares the same method name twice.
    #[error("duplicate class method `{0}`")]
    DuplicateClassMethod(Symbol),
    /// An instance implements a method its class does not declare.
    #[error("unknown method `{method}` in instance of class `{class}`")]
    UnknownInstanceMethod { class: Symbol, method: Symbol },
    /// An instance omits a method its class declares.
    #[error("missing implementation of `{method}` for instance of class `{class}`")]
    MissingInstanceMethod { class: Symbol, method: Symbol },
    /// An instance method needs a constraint the instance context does not provide.
    #[error(
        "instance method `{method}` requires constraint {class} {typ}, but it is not in the instance context"
    )]
    MissingInstanceConstraint {
        method: Symbol,
        class: Symbol,
        typ: String,
    },
    /// Reference to an unbound variable.
    #[error("unbound variable {0}")]
    UnknownVar(Symbol),
    /// An overloaded name could not be resolved to a single candidate.
    #[error("ambiguous overload for {0}")]
    AmbiguousOverload(Symbol),
    /// Quantified variables appear in constraints but not in the type
    /// (reported by `reject_ambiguous_scheme`).
    #[error("ambiguous type variable(s) {vars:?} in constraints: {constraints}")]
    AmbiguousTypeVars {
        vars: Vec<TypeVarId>,
        constraints: String,
    },
    /// A type's remaining arity does not match what class `class` expects.
    #[error(
        "kind mismatch for class `{class}`: expected {expected} type argument(s) remaining, got {got} for {typ}"
    )]
    KindMismatch {
        class: Symbol,
        expected: usize,
        got: usize,
        typ: String,
    },
    /// Inferred non-ground constraints not covered by the declared ones or
    /// their superclasses (reported by `check_non_ground_predicates_declared`).
    #[error("missing type class constraint(s): {constraints}")]
    MissingConstraints { constraints: String },
    /// An expression form the checker does not support.
    #[error("unsupported expression {0}")]
    UnsupportedExpr(&'static str),
    /// Projection of a field that does not exist on the rendered type.
    #[error("unknown field `{field}` on {typ}")]
    UnknownField { field: Symbol, typ: String },
    /// Projection of a field that is not definitely present on the rendered type.
    #[error("field `{field}` is not definitely available on {typ}")]
    FieldNotKnown { field: Symbol, typ: String },
    /// A match does not cover the listed cases.
    #[error("non-exhaustive match for {typ}: missing {missing:?}")]
    NonExhaustiveMatch { typ: String, missing: Vec<Symbol> },
    /// Another error tagged with the source span where it was reported;
    /// `with_span` never wraps an already-spanned error again.
    #[error("at {span}: {error}")]
    Spanned { span: Span, error: Box<TypeError> },
    /// Internal failure; also used when the inference depth limit is exceeded.
    #[error("internal error: {0}")]
    Internal(String),
    /// The gas meter ran out; converts from `OutOfGas` via `?`.
    #[error("{0}")]
    OutOfGas(#[from] OutOfGas),
}
759
760fn with_span(span: &Span, err: TypeError) -> TypeError {
761    match err {
762        TypeError::Spanned { .. } => err,
763        other => TypeError::Spanned {
764            span: *span,
765            error: Box::new(other),
766        },
767    }
768}
769
770fn format_constraints_referencing_vars(preds: &[Predicate], vars: &[TypeVarId]) -> String {
771    if vars.is_empty() {
772        return String::new();
773    }
774    let var_set: HashSet<TypeVarId> = vars.iter().copied().collect();
775    let mut parts = Vec::new();
776    for pred in preds {
777        let ftv = pred.ftv();
778        if ftv.iter().any(|v| var_set.contains(v)) {
779            parts.push(format!("{} {}", pred.class, pred.typ));
780        }
781    }
782    if parts.is_empty() {
783        // Fallback: show all constraints if the filtering logic misses something.
784        for pred in preds {
785            parts.push(format!("{} {}", pred.class, pred.typ));
786        }
787    }
788    parts.join(", ")
789}
790
791fn reject_ambiguous_scheme(scheme: &Scheme) -> Result<(), TypeError> {
792    // Only reject *quantified* ambiguous variables. Variables free in the
793    // environment are allowed to appear only in predicates, since they can be
794    // determined by outer context.
795    let quantified: HashSet<TypeVarId> = scheme.vars.iter().map(|v| v.id).collect();
796    if quantified.is_empty() {
797        return Ok(());
798    }
799
800    let typ_ftv = scheme.typ.ftv();
801    let mut vars = HashSet::new();
802    for pred in &scheme.preds {
803        let TypeKind::Var(tv) = pred.typ.as_ref() else {
804            continue;
805        };
806        if quantified.contains(&tv.id) && !typ_ftv.contains(&tv.id) {
807            vars.insert(tv.id);
808        }
809    }
810
811    if vars.is_empty() {
812        return Ok(());
813    }
814    let mut vars: Vec<TypeVarId> = vars.into_iter().collect();
815    vars.sort_unstable();
816    let constraints = format_constraints_referencing_vars(&scheme.preds, &vars);
817    Err(TypeError::AmbiguousTypeVars { vars, constraints })
818}
819
820fn scheme_compatible(existing: &Scheme, declared: &Scheme) -> bool {
821    let s = match unify(&existing.typ, &declared.typ) {
822        Ok(s) => s,
823        Err(_) => return false,
824    };
825
826    let existing_preds = existing.preds.apply(&s);
827    let declared_preds = declared.preds.apply(&s);
828
829    let mut lhs: Vec<(Symbol, String)> = existing_preds
830        .iter()
831        .map(|p| (p.class.clone(), p.typ.to_string()))
832        .collect();
833    let mut rhs: Vec<(Symbol, String)> = declared_preds
834        .iter()
835        .map(|p| (p.class.clone(), p.typ.to_string()))
836        .collect();
837    lhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
838    rhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
839    lhs == rhs
840}
841
/// Unification state: type-variable bindings, optional gas metering, and an
/// inference-depth guard (see `Unifier::with_infer_depth`).
#[derive(Debug)]
struct Unifier<'g> {
    // `subs[id] = Some(t)` means type variable `id` has been bound to `t`.
    //
    // This is intentionally a dense `Vec` rather than a `HashMap`: inference
    // generates `TypeVarId`s from a monotonic counter, so the common case is
    // “small id space, lots of lookups”. This makes the cost model obvious:
    // you pay O(max_id) space, and you get O(1) binds/queries.
    subs: Vec<Option<Type>>,
    // Gas meter borrowed for the unifier's lifetime; `None` when created via
    // `Unifier::new`. NOTE(review): the charging sites are outside this view.
    gas: Option<&'g mut GasMeter>,
    // Depth limit enforced by `with_infer_depth`; `None` means unlimited.
    max_infer_depth: Option<usize>,
    // Current nesting depth, incremented/decremented by `with_infer_depth`.
    infer_depth: usize,
}
855
/// Resource limits for type checking.
#[derive(Clone, Copy, Debug)]
pub struct TypeSystemLimits {
    /// Maximum inference nesting depth; `None` disables the limit.
    /// NOTE(review): presumably forwarded to the unifier's depth guard —
    /// that wiring is not visible here.
    pub max_infer_depth: Option<usize>,
}

impl TypeSystemLimits {
    /// Depth cap used by [`TypeSystemLimits::safe_defaults`].
    const SAFE_MAX_INFER_DEPTH: usize = 4096;

    /// Limits that impose no bound on inference depth.
    pub fn unlimited() -> Self {
        TypeSystemLimits {
            max_infer_depth: None,
        }
    }

    /// Conservative limits: inference depth capped at 4096. Also the `Default`.
    pub fn safe_defaults() -> Self {
        TypeSystemLimits {
            max_infer_depth: Some(Self::SAFE_MAX_INFER_DEPTH),
        }
    }
}

impl Default for TypeSystemLimits {
    /// Same as [`TypeSystemLimits::safe_defaults`].
    fn default() -> Self {
        TypeSystemLimits::safe_defaults()
    }
}
880
881fn superclass_closure(class_env: &ClassEnv, given: &[Predicate]) -> Vec<Predicate> {
882    let mut closure: Vec<Predicate> = given.to_vec();
883    let mut i = 0;
884    while i < closure.len() {
885        let p = closure[i].clone();
886        for sup in class_env.supers_of(&p.class) {
887            closure.push(Predicate::new(sup, p.typ.clone()));
888        }
889        i += 1;
890    }
891    closure
892}
893
894fn check_non_ground_predicates_declared(
895    class_env: &ClassEnv,
896    declared: &[Predicate],
897    inferred: &[Predicate],
898) -> Result<(), TypeError> {
899    // Compare by a stable, user-facing rendering (`Default a`, `Foldable t`, ...),
900    // rather than `TypeVarId`, so signature variables that only appear in
901    // predicates (and thus aren't related by unification) still match up.
902    let closure = superclass_closure(class_env, declared);
903    let closure_keys: HashSet<String> = closure
904        .iter()
905        .map(|p| format!("{} {}", p.class, p.typ))
906        .collect();
907    let mut missing = Vec::new();
908    for pred in inferred {
909        if pred.typ.ftv().is_empty() {
910            continue;
911        }
912        let key = format!("{} {}", pred.class, pred.typ);
913        if !closure_keys.contains(&key) {
914            missing.push(key);
915        }
916    }
917
918    missing.sort();
919    missing.dedup();
920    if missing.is_empty() {
921        return Ok(());
922    }
923    Err(TypeError::MissingConstraints {
924        constraints: missing.join(", "),
925    })
926}
927
928fn type_term_remaining_arity(ty: &Type) -> Option<usize> {
929    match ty.as_ref() {
930        TypeKind::Var(_) => None,
931        TypeKind::Con(tc) => Some(tc.arity),
932        TypeKind::App(l, _) => {
933            let a = type_term_remaining_arity(l)?;
934            Some(a.saturating_sub(1))
935        }
936        TypeKind::Fun(..) | TypeKind::Tuple(..) | TypeKind::Record(..) => Some(0),
937    }
938}
939
940fn max_head_app_arity_for_var(ty: &Type, var_id: TypeVarId) -> usize {
941    let mut max_arity = 0usize;
942    let mut stack: Vec<&Type> = vec![ty];
943    while let Some(t) = stack.pop() {
944        match t.as_ref() {
945            TypeKind::Var(_) | TypeKind::Con(_) => {}
946            TypeKind::App(l, r) => {
947                // Record the full application depth at this node.
948                let mut head = t;
949                let mut args = 0usize;
950                while let TypeKind::App(left, _) = head.as_ref() {
951                    args += 1;
952                    head = left;
953                }
954                if let TypeKind::Var(tv) = head.as_ref()
955                    && tv.id == var_id
956                {
957                    max_arity = max_arity.max(args);
958                }
959                stack.push(l);
960                stack.push(r);
961            }
962            TypeKind::Fun(a, b) => {
963                stack.push(a);
964                stack.push(b);
965            }
966            TypeKind::Tuple(ts) => {
967                for t in ts {
968                    stack.push(t);
969                }
970            }
971            TypeKind::Record(fields) => {
972                for (_, t) in fields {
973                    stack.push(t);
974                }
975            }
976        }
977    }
978    max_arity
979}
980
981impl<'g> Unifier<'g> {
982    fn new(max_infer_depth: Option<usize>) -> Self {
983        Self {
984            subs: Vec::new(),
985            gas: None,
986            max_infer_depth,
987            infer_depth: 0,
988        }
989    }
990
991    fn with_gas(gas: &'g mut GasMeter, max_infer_depth: Option<usize>) -> Self {
992        Self {
993            subs: Vec::new(),
994            gas: Some(gas),
995            max_infer_depth,
996            infer_depth: 0,
997        }
998    }
999
1000    fn with_infer_depth<T>(
1001        &mut self,
1002        span: Span,
1003        f: impl FnOnce(&mut Self) -> Result<T, TypeError>,
1004    ) -> Result<T, TypeError> {
1005        if let Some(max) = self.max_infer_depth
1006            && self.infer_depth >= max
1007        {
1008            return Err(TypeError::Spanned {
1009                span,
1010                error: Box::new(TypeError::Internal(format!(
1011                    "maximum inference depth exceeded (max {max})"
1012                ))),
1013            });
1014        }
1015        self.infer_depth += 1;
1016        let res = f(self);
1017        self.infer_depth = self.infer_depth.saturating_sub(1);
1018        res
1019    }
1020
1021    fn charge_infer_node(&mut self) -> Result<(), TypeError> {
1022        let Some(gas) = self.gas.as_mut() else {
1023            return Ok(());
1024        };
1025        let cost = gas.costs.infer_node;
1026        gas.charge(cost)?;
1027        Ok(())
1028    }
1029
1030    fn charge_unify_step(&mut self) -> Result<(), TypeError> {
1031        let Some(gas) = self.gas.as_mut() else {
1032            return Ok(());
1033        };
1034        let cost = gas.costs.unify_step;
1035        gas.charge(cost)?;
1036        Ok(())
1037    }
1038
1039    fn bind_var(&mut self, id: TypeVarId, ty: Type) {
1040        if id >= self.subs.len() {
1041            self.subs.resize(id + 1, None);
1042        }
1043        self.subs[id] = Some(ty);
1044    }
1045
1046    fn prune(&mut self, ty: &Type) -> Type {
1047        match ty.as_ref() {
1048            TypeKind::Var(tv) => {
1049                let bound = self.subs.get(tv.id).and_then(|t| t.clone());
1050                match bound {
1051                    Some(bound) => {
1052                        let pruned = self.prune(&bound);
1053                        self.bind_var(tv.id, pruned.clone());
1054                        pruned
1055                    }
1056                    None => ty.clone(),
1057                }
1058            }
1059            TypeKind::Con(_) => ty.clone(),
1060            TypeKind::App(l, r) => {
1061                let l = self.prune(l);
1062                let r = self.prune(r);
1063                Type::app(l, r)
1064            }
1065            TypeKind::Fun(a, b) => {
1066                let a = self.prune(a);
1067                let b = self.prune(b);
1068                Type::fun(a, b)
1069            }
1070            TypeKind::Tuple(ts) => {
1071                Type::new(TypeKind::Tuple(ts.iter().map(|t| self.prune(t)).collect()))
1072            }
1073            TypeKind::Record(fields) => Type::new(TypeKind::Record(
1074                fields
1075                    .iter()
1076                    .map(|(name, ty)| (name.clone(), self.prune(ty)))
1077                    .collect(),
1078            )),
1079        }
1080    }
1081
1082    fn apply_type(&mut self, ty: &Type) -> Type {
1083        self.prune(ty)
1084    }
1085
1086    fn occurs(&mut self, id: TypeVarId, ty: &Type) -> bool {
1087        match self.prune(ty).as_ref() {
1088            TypeKind::Var(tv) => tv.id == id,
1089            TypeKind::Con(_) => false,
1090            TypeKind::App(l, r) => self.occurs(id, l) || self.occurs(id, r),
1091            TypeKind::Fun(a, b) => self.occurs(id, a) || self.occurs(id, b),
1092            TypeKind::Tuple(ts) => ts.iter().any(|t| self.occurs(id, t)),
1093            TypeKind::Record(fields) => fields.iter().any(|(_, ty)| self.occurs(id, ty)),
1094        }
1095    }
1096
1097    fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
1098        self.charge_unify_step()?;
1099        let t1 = self.prune(t1);
1100        let t2 = self.prune(t2);
1101        match (t1.as_ref(), t2.as_ref()) {
1102            (TypeKind::Var(a), TypeKind::Var(b)) if a.id == b.id => Ok(()),
1103            (TypeKind::Var(tv), other) | (other, TypeKind::Var(tv)) => {
1104                if self.occurs(tv.id, &Type::new(other.clone())) {
1105                    Err(TypeError::Occurs(
1106                        tv.id,
1107                        Type::new(other.clone()).to_string(),
1108                    ))
1109                } else {
1110                    self.bind_var(tv.id, Type::new(other.clone()));
1111                    Ok(())
1112                }
1113            }
1114            (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(()),
1115            (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1116                self.unify(l1, l2)?;
1117                self.unify(r1, r2)
1118            }
1119            (TypeKind::Fun(a1, b1), TypeKind::Fun(a2, b2)) => {
1120                self.unify(a1, a2)?;
1121                self.unify(b1, b2)
1122            }
1123            (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1124                if ts1.len() != ts2.len() {
1125                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1126                }
1127                for (a, b) in ts1.iter().zip(ts2.iter()) {
1128                    self.unify(a, b)?;
1129                }
1130                Ok(())
1131            }
1132            (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1133                if f1.len() != f2.len() {
1134                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1135                }
1136                for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1137                    if n1 != n2 {
1138                        return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1139                    }
1140                    self.unify(t1, t2)?;
1141                }
1142                Ok(())
1143            }
1144            (TypeKind::Record(fields), TypeKind::App(head, arg))
1145            | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1146                TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1147                    let elem_ty = record_elem_type_unifier(fields, self)?;
1148                    self.unify(arg, &elem_ty)
1149                }
1150                TypeKind::Var(tv) => {
1151                    self.unify(
1152                        &Type::new(TypeKind::Var(tv.clone())),
1153                        &Type::builtin(BuiltinTypeId::Dict),
1154                    )?;
1155                    let elem_ty = record_elem_type_unifier(fields, self)?;
1156                    self.unify(arg, &elem_ty)
1157                }
1158                _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1159            },
1160            _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1161        }
1162    }
1163
1164    fn into_subst(mut self) -> Subst {
1165        let mut out = Subst::new_sync();
1166        for id in 0..self.subs.len() {
1167            if let Some(ty) = self.subs[id].clone() {
1168                let pruned = self.prune(&ty);
1169                out = out.insert(id, pruned);
1170            }
1171        }
1172        out
1173    }
1174}
1175
1176fn record_elem_type_unifier(
1177    fields: &[(Symbol, Type)],
1178    unifier: &mut Unifier<'_>,
1179) -> Result<Type, TypeError> {
1180    let mut iter = fields.iter();
1181    let first = match iter.next() {
1182        Some((_, ty)) => ty.clone(),
1183        None => return Err(TypeError::UnsupportedExpr("empty record")),
1184    };
1185    for (_, ty) in iter {
1186        unifier.unify(&first, ty)?;
1187    }
1188    Ok(unifier.apply_type(&first))
1189}
1190
1191fn bind(tv: &TypeVar, t: &Type) -> Result<Subst, TypeError> {
1192    if let TypeKind::Var(var) = t.as_ref()
1193        && var.id == tv.id
1194    {
1195        return Ok(Subst::new_sync());
1196    }
1197    if t.ftv().contains(&tv.id) {
1198        Err(TypeError::Occurs(tv.id, t.to_string()))
1199    } else {
1200        Ok(Subst::new_sync().insert(tv.id, t.clone()))
1201    }
1202}
1203
1204fn record_elem_type(fields: &[(Symbol, Type)]) -> Result<(Subst, Type), TypeError> {
1205    let mut iter = fields.iter();
1206    let first = match iter.next() {
1207        Some((_, ty)) => ty.clone(),
1208        None => return Err(TypeError::UnsupportedExpr("empty record")),
1209    };
1210    let mut subst = Subst::new_sync();
1211    let mut current = first;
1212    for (_, ty) in iter {
1213        let s_next = unify(&current.apply(&subst), &ty.apply(&subst))?;
1214        subst = compose_subst(s_next, subst);
1215        current = current.apply(&subst);
1216    }
1217    Ok((subst.clone(), current.apply(&subst)))
1218}
1219
1220/// Compute a most-general unifier for two types.
1221///
1222/// This is the “pure” unifier: it returns an explicit substitution map and is
1223/// easy to read/compose in isolation. The type inference engine uses `Unifier`
1224/// directly to avoid allocating and composing persistent maps at every
1225/// unification step.
1226pub fn unify(t1: &Type, t2: &Type) -> Result<Subst, TypeError> {
1227    match (t1.as_ref(), t2.as_ref()) {
1228        (TypeKind::Fun(l1, r1), TypeKind::Fun(l2, r2)) => {
1229            let s1 = unify(l1, l2)?;
1230            let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1231            Ok(compose_subst(s2, s1))
1232        }
1233        (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1234            if f1.len() != f2.len() {
1235                return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1236            }
1237            let mut subst = Subst::new_sync();
1238            for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1239                if n1 != n2 {
1240                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1241                }
1242                let s_next = unify(&t1.apply(&subst), &t2.apply(&subst))?;
1243                subst = compose_subst(s_next, subst);
1244            }
1245            Ok(subst)
1246        }
1247        (TypeKind::Record(fields), TypeKind::App(head, arg))
1248        | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1249            TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1250                let (s_fields, elem_ty) = record_elem_type(fields)?;
1251                let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1252                Ok(compose_subst(s_arg, s_fields))
1253            }
1254            TypeKind::Var(tv) => {
1255                let s_head = bind(tv, &Type::builtin(BuiltinTypeId::Dict))?;
1256                let arg = arg.apply(&s_head);
1257                let (s_fields, elem_ty) = record_elem_type(fields)?;
1258                let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1259                Ok(compose_subst(s_arg, compose_subst(s_fields, s_head)))
1260            }
1261            _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1262        },
1263        (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1264            let s1 = unify(l1, l2)?;
1265            let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1266            Ok(compose_subst(s2, s1))
1267        }
1268        (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1269            if ts1.len() != ts2.len() {
1270                return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1271            }
1272            let mut s = Subst::new_sync();
1273            for (a, b) in ts1.iter().zip(ts2.iter()) {
1274                let s_next = unify(&a.apply(&s), &b.apply(&s))?;
1275                s = compose_subst(s_next, s);
1276            }
1277            Ok(s)
1278        }
1279        (TypeKind::Var(tv), t) | (t, TypeKind::Var(tv)) => bind(tv, &Type::new(t.clone())),
1280        (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(Subst::new_sync()),
1281        _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1282    }
1283}
1284
1285#[derive(Default, Debug, Clone)]
1286pub struct TypeEnv {
1287    pub values: HashTrieMapSync<Symbol, Vec<Scheme>>,
1288}
1289
1290impl TypeEnv {
1291    pub fn new() -> Self {
1292        Self {
1293            values: HashTrieMapSync::new_sync(),
1294        }
1295    }
1296
1297    pub fn extend(&mut self, name: Symbol, scheme: Scheme) {
1298        self.values = self.values.insert(name, vec![scheme]);
1299    }
1300
1301    pub fn extend_overload(&mut self, name: Symbol, scheme: Scheme) {
1302        let mut schemes = self.values.get(&name).cloned().unwrap_or_default();
1303        schemes.push(scheme);
1304        self.values = self.values.insert(name, schemes);
1305    }
1306
1307    pub fn remove(&mut self, name: &Symbol) {
1308        self.values = self.values.remove(name);
1309    }
1310
1311    pub fn lookup(&self, name: &Symbol) -> Option<&[Scheme]> {
1312        self.values.get(name).map(|schemes| schemes.as_slice())
1313    }
1314}
1315
1316impl Types for TypeEnv {
1317    fn apply(&self, s: &Subst) -> Self {
1318        let mut values = HashTrieMapSync::new_sync();
1319        for (k, v) in self.values.iter() {
1320            let updated = v
1321                .iter()
1322                .map(|scheme| {
1323                    // Most schemes in environments are monomorphic. Don't walk
1324                    // and rebuild trees unless we actually have work to do.
1325                    if scheme.vars.is_empty() && !subst_is_empty(s) {
1326                        scheme.apply(s)
1327                    } else {
1328                        scheme.clone()
1329                    }
1330                })
1331                .collect();
1332            values = values.insert(k.clone(), updated);
1333        }
1334        TypeEnv { values }
1335    }
1336
1337    fn ftv(&self) -> HashSet<TypeVarId> {
1338        self.values
1339            .iter()
1340            .flat_map(|(_, schemes)| schemes.iter().flat_map(Types::ftv))
1341            .collect()
1342    }
1343}
1344
1345#[derive(Default, Debug, Clone)]
1346pub struct TypeVarSupply {
1347    counter: TypeVarId,
1348}
1349
1350impl TypeVarSupply {
1351    pub fn new() -> Self {
1352        Self { counter: 0 }
1353    }
1354
1355    pub fn fresh(&mut self, name_hint: impl Into<Option<Symbol>>) -> TypeVar {
1356        let tv = TypeVar::new(self.counter, name_hint.into());
1357        self.counter += 1;
1358        tv
1359    }
1360}
1361
1362fn is_integral_literal_expr(expr: &Expr) -> bool {
1363    matches!(expr, Expr::Int(..) | Expr::Uint(..))
1364}
1365
1366/// Turn a monotype `typ` (plus constraints `preds`) into a polymorphic `Scheme`
1367/// by quantifying over the type variables not free in `env`.
1368pub fn generalize(env: &TypeEnv, preds: Vec<Predicate>, typ: Type) -> Scheme {
1369    let mut vars: Vec<TypeVar> = typ
1370        .ftv()
1371        .union(&preds.ftv())
1372        .copied()
1373        .collect::<HashSet<_>>()
1374        .difference(&env.ftv())
1375        .cloned()
1376        .map(|id| TypeVar::new(id, None))
1377        .collect();
1378    vars.sort_by_key(|v| v.id);
1379    Scheme::new(vars, preds, typ)
1380}
1381
1382pub fn instantiate(scheme: &Scheme, supply: &mut TypeVarSupply) -> (Vec<Predicate>, Type) {
1383    // Instantiate replaces all quantified variables with fresh unification
1384    // variables, preserving the original name as a debugging hint.
1385    let mut subst = Subst::new_sync();
1386    for v in &scheme.vars {
1387        subst = subst.insert(v.id, Type::var(supply.fresh(v.name.clone())));
1388    }
1389    (scheme.preds.apply(&subst), scheme.typ.apply(&subst))
1390}
1391
1392/// A named type parameter for an ADT (e.g. `a` in `List a`).
1393#[derive(Clone, Debug)]
1394pub struct AdtParam {
1395    pub name: Symbol,
1396    pub var: TypeVar,
1397}
1398
1399/// A single ADT variant with zero or more constructor arguments.
1400#[derive(Clone, Debug)]
1401pub struct AdtVariant {
1402    pub name: Symbol,
1403    pub args: Vec<Type>,
1404}
1405
1406/// A type declaration for an algebraic data type.
1407///
1408/// This only describes the *type* surface (params + variants). It does not
1409/// introduce any runtime values by itself. Runtime values are created by
1410/// injecting constructor schemes into the environment (see `inject_adt`).
1411#[derive(Clone, Debug)]
1412pub struct AdtDecl {
1413    pub name: Symbol,
1414    pub params: Vec<AdtParam>,
1415    pub variants: Vec<AdtVariant>,
1416}
1417
1418impl AdtDecl {
1419    pub fn new(name: &Symbol, param_names: &[Symbol], supply: &mut TypeVarSupply) -> Self {
1420        let params = param_names
1421            .iter()
1422            .map(|p| AdtParam {
1423                name: p.clone(),
1424                var: supply.fresh(Some(p.clone())),
1425            })
1426            .collect();
1427        Self {
1428            name: name.clone(),
1429            params,
1430            variants: Vec::new(),
1431        }
1432    }
1433
1434    pub fn param_type(&self, name: &Symbol) -> Option<Type> {
1435        self.params
1436            .iter()
1437            .find(|p| &p.name == name)
1438            .map(|p| Type::var(p.var.clone()))
1439    }
1440
1441    pub fn add_variant(&mut self, name: Symbol, args: Vec<Type>) {
1442        self.variants.push(AdtVariant { name, args });
1443    }
1444
1445    pub fn result_type(&self) -> Type {
1446        let mut ty = Type::con(&self.name, self.params.len());
1447        for param in &self.params {
1448            ty = Type::app(ty, Type::var(param.var.clone()));
1449        }
1450        ty
1451    }
1452
1453    /// Build constructor schemes of the form:
1454    /// `C :: a1 -> a2 -> ... -> T params`.
1455    pub fn constructor_schemes(&self) -> Vec<(Symbol, Scheme)> {
1456        let result_ty = self.result_type();
1457        let vars: Vec<TypeVar> = self.params.iter().map(|p| p.var.clone()).collect();
1458        let mut out = Vec::new();
1459        for variant in &self.variants {
1460            let mut typ = result_ty.clone();
1461            for arg in variant.args.iter().rev() {
1462                typ = Type::fun(arg.clone(), typ);
1463            }
1464            out.push((variant.name.clone(), Scheme::new(vars.clone(), vec![], typ)));
1465        }
1466        out
1467    }
1468}
1469
/// A registered type class, as tracked by `ClassEnv`: only the names of its
/// direct superclasses.
#[derive(Clone, Debug)]
pub struct Class {
    /// Direct superclasses (e.g. from `<= C a` in a class declaration).
    pub supers: Vec<Symbol>,
}

impl Class {
    pub fn new(supers: Vec<Symbol>) -> Self {
        Self { supers }
    }
}

/// A class instance `context => head`.
///
/// `head` is the predicate the instance provides. It applies to a goal
/// predicate whose type unifies with `head.typ`, provided every predicate in
/// `context` (under that unifier) is itself entailed; see `entails`.
#[derive(Clone, Debug)]
pub struct Instance {
    pub context: Vec<Predicate>,
    pub head: Predicate,
}

impl Instance {
    pub fn new(context: Vec<Predicate>, head: Predicate) -> Self {
        Self { context, head }
    }
}
1492
1493#[derive(Default, Debug, Clone)]
1494pub struct ClassEnv {
1495    pub classes: HashMap<Symbol, Class>,
1496    pub instances: HashMap<Symbol, Vec<Instance>>,
1497}
1498
1499impl ClassEnv {
1500    pub fn new() -> Self {
1501        Self {
1502            classes: HashMap::new(),
1503            instances: HashMap::new(),
1504        }
1505    }
1506
1507    pub fn add_class(&mut self, name: Symbol, supers: Vec<Symbol>) {
1508        self.classes.insert(name, Class::new(supers));
1509    }
1510
1511    pub fn add_instance(&mut self, class: Symbol, inst: Instance) {
1512        self.instances.entry(class).or_default().push(inst);
1513    }
1514
1515    pub fn supers_of(&self, class: &Symbol) -> Vec<Symbol> {
1516        self.classes
1517            .get(class)
1518            .map(|c| c.supers.clone())
1519            .unwrap_or_default()
1520    }
1521}
1522
1523pub fn entails(
1524    class_env: &ClassEnv,
1525    given: &[Predicate],
1526    pred: &Predicate,
1527) -> Result<bool, TypeError> {
1528    // Expand given with superclasses.
1529    let mut closure: Vec<Predicate> = given.to_vec();
1530    let mut i = 0;
1531    while i < closure.len() {
1532        let p = closure[i].clone();
1533        for sup in class_env.supers_of(&p.class) {
1534            closure.push(Predicate::new(sup, p.typ.clone()));
1535        }
1536        i += 1;
1537    }
1538
1539    if closure
1540        .iter()
1541        .any(|p| p.class == pred.class && p.typ == pred.typ)
1542    {
1543        return Ok(true);
1544    }
1545
1546    if !class_env.classes.contains_key(&pred.class) {
1547        return Err(TypeError::UnknownClass(pred.class.clone()));
1548    }
1549
1550    if let Some(instances) = class_env.instances.get(&pred.class) {
1551        for inst in instances {
1552            if let Ok(s) = unify(&inst.head.typ, &pred.typ) {
1553                let ctx = inst.context.apply(&s);
1554                if ctx
1555                    .iter()
1556                    .all(|c| entails(class_env, &closure, c).unwrap_or(false))
1557                {
1558                    return Ok(true);
1559                }
1560            }
1561        }
1562    }
1563    Ok(false)
1564}
1565
/// Global typing state: value environment, class hierarchy/instances, ADT
/// declarations, class metadata, and the type-variable supply.
#[derive(Default, Debug, Clone)]
pub struct TypeSystem {
    /// Value-level environment (names to schemes, possibly overloaded).
    pub env: TypeEnv,
    /// Class hierarchy and instances, consulted by `entails`.
    pub classes: ClassEnv,
    /// ADT declarations; presumably keyed by `AdtDecl::name` — TODO confirm.
    pub adts: HashMap<Symbol, AdtDecl>,
    /// Per-class metadata, keyed by class name (filled by `register_class_decl`).
    pub class_info: HashMap<Symbol, ClassInfo>,
    /// Per-method metadata, keyed by method name (filled by `register_class_decl`).
    pub class_methods: HashMap<Symbol, ClassMethodInfo>,
    /// Names introduced by `declare fn` (forward declarations).
    ///
    /// These are placeholders in the type environment and must not block a later
    /// real definition (e.g. `fn foo = ...` or host/CLI injection).
    pub declared_values: HashSet<Symbol>,
    /// Source of fresh type variables for everything registered here.
    pub supply: TypeVarSupply,
    /// Configured limits; replaced via `set_limits`.
    limits: TypeSystemLimits,
}

/// Semantic information about a type class declaration, derived from Rex source.
///
/// Design notes (WARM):
/// - We keep this explicit and data-oriented: it makes review easy and keeps costs visible.
/// - Rex represents multi-parameter classes by encoding the parameters as a tuple in the
///   single `Predicate.typ` slot. For a unary class `C a` the predicate is `C a`. For a
///   binary class `C t a` the predicate is `C (t, a)`, etc.
/// - This keeps the runtime/type-inference machinery simple: instance matching is still
///   “unify the predicate types”, and no separate arity tracking is needed.
#[derive(Clone, Debug)]
pub struct ClassInfo {
    /// The class name.
    pub name: Symbol,
    /// Declared parameter names (never empty for source-declared classes).
    pub params: Vec<Symbol>,
    /// Direct superclass names.
    pub supers: Vec<Symbol>,
    /// Method name to scheme; each scheme carries the class predicate.
    pub methods: BTreeMap<Symbol, Scheme>,
}

/// Which class declares a method, and that method's scheme.
#[derive(Clone, Debug)]
pub struct ClassMethodInfo {
    pub class: Symbol,
    pub scheme: Scheme,
}

/// Result of `TypeSystem::register_instance_decl`: the instance's resolved
/// head type and context, plus its class and source span.
#[derive(Clone, Debug)]
pub struct PreparedInstanceDecl {
    pub span: Span,
    pub class: Symbol,
    pub head: Type,
    pub context: Vec<Predicate>,
}
1607    pub span: Span,
1608    pub class: Symbol,
1609    pub head: Type,
1610    pub context: Vec<Predicate>,
1611}
1612
1613impl TypeSystem {
1614    pub fn new() -> Self {
1615        Self {
1616            env: TypeEnv::new(),
1617            classes: ClassEnv::new(),
1618            adts: HashMap::new(),
1619            class_info: HashMap::new(),
1620            class_methods: HashMap::new(),
1621            declared_values: HashSet::new(),
1622            supply: TypeVarSupply::new(),
1623            limits: TypeSystemLimits::default(),
1624        }
1625    }
1626
1627    pub fn fresh_type_var(&mut self, name: Option<Symbol>) -> TypeVar {
1628        self.supply.fresh(name)
1629    }
1630
1631    pub fn set_limits(&mut self, limits: TypeSystemLimits) {
1632        self.limits = limits;
1633    }
1634
1635    pub fn new_with_prelude() -> Result<Self, TypeError> {
1636        let mut ts = TypeSystem::new();
1637        prelude::build_prelude(&mut ts)?;
1638        Ok(ts)
1639    }
1640
1641    fn register_decl(&mut self, decl: &Decl) -> Result<(), TypeError> {
1642        match decl {
1643            Decl::Type(ty) => self.register_type_decl(ty),
1644            Decl::Class(class_decl) => self.register_class_decl(class_decl),
1645            Decl::Instance(inst_decl) => {
1646                let _ = self.register_instance_decl(inst_decl)?;
1647                Ok(())
1648            }
1649            Decl::Fn(fd) => self.register_fn_decls(std::slice::from_ref(fd)),
1650            Decl::DeclareFn(fd) => self.inject_declare_fn_decl(fd),
1651            Decl::Import(..) => Ok(()),
1652        }
1653    }
1654
1655    pub fn register_decls(&mut self, decls: &[Decl]) -> Result<(), TypeError> {
1656        let mut pending_fns: Vec<FnDecl> = Vec::new();
1657        for decl in decls {
1658            if let Decl::Fn(fd) = decl {
1659                pending_fns.push(fd.clone());
1660                continue;
1661            }
1662
1663            if !pending_fns.is_empty() {
1664                self.register_fn_decls(&pending_fns)?;
1665                pending_fns.clear();
1666            }
1667
1668            self.register_decl(decl)?;
1669        }
1670        if !pending_fns.is_empty() {
1671            self.register_fn_decls(&pending_fns)?;
1672        }
1673        Ok(())
1674    }
1675
1676    pub fn add_value(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1677        let name = sym(name.as_ref());
1678        self.declared_values.remove(&name);
1679        self.env.extend(name, scheme);
1680    }
1681
1682    pub fn add_overload(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1683        let name = sym(name.as_ref());
1684        self.declared_values.remove(&name);
1685        self.env.extend_overload(name, scheme);
1686    }
1687
1688    pub fn register_instance(&mut self, class: impl AsRef<str>, inst: Instance) {
1689        self.classes.add_instance(sym(class.as_ref()), inst);
1690    }
1691
1692    pub fn register_class_decl(&mut self, decl: &ClassDecl) -> Result<(), TypeError> {
1693        let span = decl.span;
1694        (|| {
1695            // Classes are global, and Rex does not support reopening/merging them.
1696            // Allowing that would be a long-term maintenance hazard: it creates
1697            // spooky-action-at-a-distance across modules and makes reviews harder.
1698            if self.class_info.contains_key(&decl.name)
1699                || self.classes.classes.contains_key(&decl.name)
1700            {
1701                return Err(TypeError::DuplicateClass(decl.name.clone()));
1702            }
1703            if decl.params.is_empty() {
1704                return Err(TypeError::InvalidClassArity {
1705                    class: decl.name.clone(),
1706                    got: decl.params.len(),
1707                });
1708            }
1709            let params = decl.params.clone();
1710
1711            // Register the superclass relationships in the class environment.
1712            //
1713            // We only accept `<= C param` style superclasses for now. Anything
1714            // fancier would require storing type-level relationships in `ClassEnv`,
1715            // which Rex does not currently model.
1716            let mut supers = Vec::with_capacity(decl.supers.len());
1717            if !decl.supers.is_empty() && params.len() != 1 {
1718                return Err(TypeError::UnsupportedExpr(
1719                    "multi-parameter classes cannot declare superclasses yet",
1720                ));
1721            }
1722            for sup in &decl.supers {
1723                let mut vars = HashMap::new();
1724                let param = params[0].clone();
1725                let param_tv = self.supply.fresh(Some(param.clone()));
1726                vars.insert(param, param_tv.clone());
1727                let sup_ty = type_from_annotation_expr_vars(
1728                    &self.adts,
1729                    &sup.typ,
1730                    &mut vars,
1731                    &mut self.supply,
1732                )?;
1733                if sup_ty != Type::var(param_tv) {
1734                    return Err(TypeError::UnsupportedExpr(
1735                        "superclass constraints must be of the form `<= C a`",
1736                    ));
1737                }
1738                supers.push(sup.class.to_dotted_symbol());
1739            }
1740
1741            self.classes.add_class(decl.name.clone(), supers.clone());
1742
1743            let mut methods = BTreeMap::new();
1744            for ClassMethodSig { name, typ } in &decl.methods {
1745                if self.env.lookup(name).is_some() || self.class_methods.contains_key(name) {
1746                    return Err(TypeError::DuplicateClassMethod(name.clone()));
1747                }
1748
1749                let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1750                let mut param_tvs: Vec<TypeVar> = Vec::with_capacity(params.len());
1751                for param in &params {
1752                    let tv = self.supply.fresh(Some(param.clone()));
1753                    vars.insert(param.clone(), tv.clone());
1754                    param_tvs.push(tv);
1755                }
1756
1757                let ty =
1758                    type_from_annotation_expr_vars(&self.adts, typ, &mut vars, &mut self.supply)?;
1759
1760                let mut scheme_vars: Vec<TypeVar> = vars.values().cloned().collect();
1761                scheme_vars.sort_by_key(|tv| tv.id);
1762                scheme_vars.dedup_by_key(|tv| tv.id);
1763
1764                let class_pred = Predicate {
1765                    class: decl.name.clone(),
1766                    typ: if param_tvs.len() == 1 {
1767                        Type::var(param_tvs[0].clone())
1768                    } else {
1769                        Type::tuple(param_tvs.into_iter().map(Type::var).collect())
1770                    },
1771                };
1772                let scheme = Scheme::new(scheme_vars, vec![class_pred], ty);
1773
1774                self.env.extend(name.clone(), scheme.clone());
1775                self.class_methods.insert(
1776                    name.clone(),
1777                    ClassMethodInfo {
1778                        class: decl.name.clone(),
1779                        scheme: scheme.clone(),
1780                    },
1781                );
1782                methods.insert(name.clone(), scheme);
1783            }
1784
1785            self.class_info.insert(
1786                decl.name.clone(),
1787                ClassInfo {
1788                    name: decl.name.clone(),
1789                    params,
1790                    supers,
1791                    methods,
1792                },
1793            );
1794            Ok(())
1795        })()
1796        .map_err(|err| with_span(&span, err))
1797    }
1798
1799    pub fn register_instance_decl(
1800        &mut self,
1801        decl: &InstanceDecl,
1802    ) -> Result<PreparedInstanceDecl, TypeError> {
1803        let span = decl.span;
1804        (|| {
1805            let class = decl.class.clone();
1806            if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
1807                return Err(TypeError::UnknownClass(class));
1808            }
1809
1810            let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1811            let head = type_from_annotation_expr_vars(
1812                &self.adts,
1813                &decl.head,
1814                &mut vars,
1815                &mut self.supply,
1816            )?;
1817            let context = predicates_from_constraints(
1818                &self.adts,
1819                &decl.context,
1820                &mut vars,
1821                &mut self.supply,
1822            )?;
1823
1824            let inst = Instance::new(
1825                context.clone(),
1826                Predicate {
1827                    class: decl.class.clone(),
1828                    typ: head.clone(),
1829                },
1830            );
1831
1832            // Validate method list against the class declaration if present.
1833            if let Some(info) = self.class_info.get(&decl.class) {
1834                for method in &decl.methods {
1835                    if !info.methods.contains_key(&method.name) {
1836                        return Err(TypeError::UnknownInstanceMethod {
1837                            class: decl.class.clone(),
1838                            method: method.name.clone(),
1839                        });
1840                    }
1841                }
1842                for method_name in info.methods.keys() {
1843                    if !decl.methods.iter().any(|m| &m.name == method_name) {
1844                        return Err(TypeError::MissingInstanceMethod {
1845                            class: decl.class.clone(),
1846                            method: method_name.clone(),
1847                        });
1848                    }
1849                }
1850            }
1851
1852            self.classes.add_instance(decl.class.clone(), inst);
1853            Ok(PreparedInstanceDecl {
1854                span,
1855                class: decl.class.clone(),
1856                head,
1857                context,
1858            })
1859        })()
1860        .map_err(|err| with_span(&span, err))
1861    }
1862
1863    pub fn prepare_instance_decl(
1864        &mut self,
1865        decl: &InstanceDecl,
1866    ) -> Result<PreparedInstanceDecl, TypeError> {
1867        let span = decl.span;
1868        (|| {
1869            let class = decl.class.clone();
1870            if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
1871                return Err(TypeError::UnknownClass(class));
1872            }
1873
1874            let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1875            let head = type_from_annotation_expr_vars(
1876                &self.adts,
1877                &decl.head,
1878                &mut vars,
1879                &mut self.supply,
1880            )?;
1881            let context = predicates_from_constraints(
1882                &self.adts,
1883                &decl.context,
1884                &mut vars,
1885                &mut self.supply,
1886            )?;
1887
1888            // Validate method list against the class declaration if present.
1889            if let Some(info) = self.class_info.get(&decl.class) {
1890                for method in &decl.methods {
1891                    if !info.methods.contains_key(&method.name) {
1892                        return Err(TypeError::UnknownInstanceMethod {
1893                            class: decl.class.clone(),
1894                            method: method.name.clone(),
1895                        });
1896                    }
1897                }
1898                for method_name in info.methods.keys() {
1899                    if !decl.methods.iter().any(|m| &m.name == method_name) {
1900                        return Err(TypeError::MissingInstanceMethod {
1901                            class: decl.class.clone(),
1902                            method: method_name.clone(),
1903                        });
1904                    }
1905                }
1906            }
1907
1908            Ok(PreparedInstanceDecl {
1909                span,
1910                class: decl.class.clone(),
1911                head,
1912                context,
1913            })
1914        })()
1915        .map_err(|err| with_span(&span, err))
1916    }
1917
1918    pub fn register_fn_decls(&mut self, decls: &[FnDecl]) -> Result<(), TypeError> {
1919        if decls.is_empty() {
1920            return Ok(());
1921        }
1922
1923        let saved_env = self.env.clone();
1924        let saved_declared = self.declared_values.clone();
1925
1926        let result: Result<(), TypeError> = (|| {
1927            #[derive(Clone)]
1928            struct FnInfo {
1929                decl: FnDecl,
1930                expected: Type,
1931                declared_preds: Vec<Predicate>,
1932                scheme: Scheme,
1933                ann_vars: HashMap<Symbol, TypeVar>,
1934            }
1935
1936            let mut infos: Vec<FnInfo> = Vec::with_capacity(decls.len());
1937            let mut seen_names = HashSet::new();
1938
1939            for decl in decls {
1940                let span = decl.span;
1941                let info = (|| {
1942                    let name = &decl.name.name;
1943                    if !seen_names.insert(name.clone()) {
1944                        return Err(TypeError::DuplicateValue(name.clone()));
1945                    }
1946
1947                    if self.env.lookup(name).is_some() {
1948                        if self.declared_values.remove(name) {
1949                            // A forward declaration should not block the real definition.
1950                            self.env.remove(name);
1951                        } else {
1952                            return Err(TypeError::DuplicateValue(name.clone()));
1953                        }
1954                    }
1955
1956                    let mut sig = decl.ret.clone();
1957                    for (_, ann) in decl.params.iter().rev() {
1958                        let span = Span::from_begin_end(ann.span().begin, sig.span().end);
1959                        sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
1960                    }
1961
1962                    let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
1963                    let expected = type_from_annotation_expr_vars(
1964                        &self.adts,
1965                        &sig,
1966                        &mut ann_vars,
1967                        &mut self.supply,
1968                    )?;
1969                    let declared_preds = predicates_from_constraints(
1970                        &self.adts,
1971                        &decl.constraints,
1972                        &mut ann_vars,
1973                        &mut self.supply,
1974                    )?;
1975
1976                    // Validate that declared constraints are well-formed.
1977                    let var_arities: HashMap<TypeVarId, usize> = ann_vars
1978                        .values()
1979                        .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
1980                        .collect();
1981                    for pred in &declared_preds {
1982                        let _ = entails(&self.classes, &[], pred)?;
1983                        let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
1984                        else {
1985                            continue;
1986                        };
1987                        let args: Vec<Type> = if expected_arities.len() == 1 {
1988                            vec![pred.typ.clone()]
1989                        } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
1990                            if parts.len() != expected_arities.len() {
1991                                continue;
1992                            }
1993                            parts.clone()
1994                        } else {
1995                            continue;
1996                        };
1997
1998                        for (arg, expected_arity) in
1999                            args.iter().zip(expected_arities.iter().copied())
2000                        {
2001                            let got =
2002                                type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
2003                                    TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
2004                                    _ => None,
2005                                });
2006                            let Some(got) = got else {
2007                                continue;
2008                            };
2009                            if got != expected_arity {
2010                                return Err(TypeError::KindMismatch {
2011                                    class: pred.class.clone(),
2012                                    expected: expected_arity,
2013                                    got,
2014                                    typ: arg.to_string(),
2015                                });
2016                            }
2017                        }
2018                    }
2019
2020                    let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
2021                    vars.sort_by_key(|v| v.id);
2022                    let scheme = Scheme::new(vars, declared_preds.clone(), expected.clone());
2023                    reject_ambiguous_scheme(&scheme)?;
2024
2025                    Ok(FnInfo {
2026                        decl: decl.clone(),
2027                        expected,
2028                        declared_preds,
2029                        scheme,
2030                        ann_vars,
2031                    })
2032                })();
2033
2034                infos.push(info.map_err(|err| with_span(&span, err))?);
2035            }
2036
2037            // Seed environment with all declared signatures first so fn bodies
2038            // can reference each other recursively (let-rec semantics).
2039            for info in &infos {
2040                self.env
2041                    .extend(info.decl.name.name.clone(), info.scheme.clone());
2042            }
2043
2044            for info in infos {
2045                let span = info.decl.span;
2046                let mut lam_body = info.decl.body.clone();
2047                let mut lam_end = lam_body.span().end;
2048                for (param, ann) in info.decl.params.iter().rev() {
2049                    let lam_constraints = Vec::new();
2050                    let span = Span::from_begin_end(param.span.begin, lam_end);
2051                    lam_body = Arc::new(Expr::Lam(
2052                        span,
2053                        Scope::new_sync(),
2054                        param.clone(),
2055                        Some(ann.clone()),
2056                        lam_constraints,
2057                        lam_body,
2058                    ));
2059                    lam_end = lam_body.span().end;
2060                }
2061
2062                let (typed, preds, inferred) = infer_typed(self, lam_body.as_ref())?;
2063                let s = unify(&inferred, &info.expected)?;
2064                let preds = preds.apply(&s);
2065                let inferred = inferred.apply(&s);
2066                let declared_preds = info.declared_preds.apply(&s);
2067                let expected = info.expected.apply(&s);
2068
2069                // Keep kind checks aligned with existing `inject_fn_decl` logic.
2070                let var_arities: HashMap<TypeVarId, usize> = info
2071                    .ann_vars
2072                    .values()
2073                    .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
2074                    .collect();
2075                for pred in &declared_preds {
2076                    let _ = entails(&self.classes, &[], pred)?;
2077                    let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
2078                    else {
2079                        continue;
2080                    };
2081                    let args: Vec<Type> = if expected_arities.len() == 1 {
2082                        vec![pred.typ.clone()]
2083                    } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2084                        if parts.len() != expected_arities.len() {
2085                            continue;
2086                        }
2087                        parts.clone()
2088                    } else {
2089                        continue;
2090                    };
2091
2092                    for (arg, expected_arity) in args.iter().zip(expected_arities.iter().copied()) {
2093                        let got = type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
2094                            TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
2095                            _ => None,
2096                        });
2097                        let Some(got) = got else {
2098                            continue;
2099                        };
2100                        if got != expected_arity {
2101                            return Err(with_span(
2102                                &span,
2103                                TypeError::KindMismatch {
2104                                    class: pred.class.clone(),
2105                                    expected: expected_arity,
2106                                    got,
2107                                    typ: arg.to_string(),
2108                                },
2109                            ));
2110                        }
2111                    }
2112                }
2113
2114                check_non_ground_predicates_declared(&self.classes, &declared_preds, &preds)
2115                    .map_err(|err| with_span(&span, err))?;
2116
2117                let _ = inferred;
2118                let _ = typed;
2119            }
2120
2121            Ok(())
2122        })();
2123
2124        if result.is_err() {
2125            self.env = saved_env;
2126            self.declared_values = saved_declared;
2127        }
2128        result
2129    }
2130
2131    pub fn inject_declare_fn_decl(&mut self, decl: &DeclareFnDecl) -> Result<(), TypeError> {
2132        let span = decl.span;
2133        (|| {
2134            // Build the declared signature type.
2135            let mut sig = decl.ret.clone();
2136            for (_, ann) in decl.params.iter().rev() {
2137                let span = Span::from_begin_end(ann.span().begin, sig.span().end);
2138                sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
2139            }
2140
2141            let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
2142            let expected =
2143                type_from_annotation_expr_vars(&self.adts, &sig, &mut ann_vars, &mut self.supply)?;
2144            let declared_preds = predicates_from_constraints(
2145                &self.adts,
2146                &decl.constraints,
2147                &mut ann_vars,
2148                &mut self.supply,
2149            )?;
2150
2151            let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
2152            vars.sort_by_key(|v| v.id);
2153            let scheme = Scheme::new(vars, declared_preds, expected);
2154            reject_ambiguous_scheme(&scheme)?;
2155
2156            // Validate referenced classes exist (and are spelled correctly).
2157            for pred in &scheme.preds {
2158                let _ = entails(&self.classes, &[], pred)?;
2159            }
2160
2161            let name = &decl.name.name;
2162
2163            // If there is already a real definition (prelude/host/`fn`), treat
2164            // `declare fn` as documentation only and ignore it.
2165            if self.env.lookup(name).is_some() && !self.declared_values.contains(name) {
2166                return Ok(());
2167            }
2168
2169            if let Some(existing) = self.env.lookup(name) {
2170                if existing.iter().any(|s| scheme_compatible(s, &scheme)) {
2171                    return Ok(());
2172                }
2173                return Err(TypeError::DuplicateValue(decl.name.name.clone()));
2174            }
2175
2176            self.env.extend(decl.name.name.clone(), scheme);
2177            self.declared_values.insert(decl.name.name.clone());
2178            Ok(())
2179        })()
2180        .map_err(|err| with_span(&span, err))
2181    }
2182
2183    pub fn instantiate_class_method_for_head(
2184        &mut self,
2185        class: &Symbol,
2186        method: &Symbol,
2187        head: &Type,
2188    ) -> Result<Type, TypeError> {
2189        let info = self
2190            .class_info
2191            .get(class)
2192            .ok_or_else(|| TypeError::UnknownClass(class.clone()))?;
2193        let scheme = info
2194            .methods
2195            .get(method)
2196            .ok_or_else(|| TypeError::UnknownInstanceMethod {
2197                class: class.clone(),
2198                method: method.clone(),
2199            })?;
2200
2201        let (preds, typ) = instantiate(scheme, &mut self.supply);
2202        let class_pred =
2203            preds
2204                .iter()
2205                .find(|p| &p.class == class)
2206                .ok_or(TypeError::UnsupportedExpr(
2207                    "class method scheme missing class predicate",
2208                ))?;
2209        let s = unify(&class_pred.typ, head)?;
2210        Ok(typ.apply(&s))
2211    }
2212
2213    pub fn typecheck_instance_method(
2214        &mut self,
2215        prepared: &PreparedInstanceDecl,
2216        method: &InstanceMethodImpl,
2217    ) -> Result<TypedExpr, TypeError> {
2218        let expected =
2219            self.instantiate_class_method_for_head(&prepared.class, &method.name, &prepared.head)?;
2220        let (typed, preds, actual) = infer_typed(self, method.body.as_ref())?;
2221        let s = unify(&actual, &expected)?;
2222        let typed = typed.apply(&s);
2223        let preds = preds.apply(&s);
2224
2225        // The only legal “given” constraints inside an instance method are the
2226        // instance context (plus superclass closure, plus the instance head
2227        // itself). We do *not* allow instance
2228        // search for non-ground constraints here, because that would be unsound:
2229        // a type variable would unify with any concrete instance head.
2230        let mut given = prepared.context.clone();
2231
2232        // Allow recursive instance methods (e.g. `Eq (List a)` calling `(==)`
2233        // on the tail). This is dictionary recursion, not instance search.
2234        given.push(Predicate::new(
2235            prepared.class.clone(),
2236            prepared.head.clone(),
2237        ));
2238        let mut i = 0;
2239        while i < given.len() {
2240            let p = given[i].clone();
2241            for sup in self.classes.supers_of(&p.class) {
2242                given.push(Predicate::new(sup, p.typ.clone()));
2243            }
2244            i += 1;
2245        }
2246
2247        for pred in &preds {
2248            if pred.typ.ftv().is_empty() {
2249                if !entails(&self.classes, &given, pred)? {
2250                    return Err(TypeError::NoInstance(
2251                        pred.class.clone(),
2252                        pred.typ.to_string(),
2253                    ));
2254                }
2255            } else if !given
2256                .iter()
2257                .any(|p| p.class == pred.class && p.typ == pred.typ)
2258            {
2259                return Err(TypeError::MissingInstanceConstraint {
2260                    method: method.name.clone(),
2261                    class: pred.class.clone(),
2262                    typ: pred.typ.to_string(),
2263                });
2264            }
2265        }
2266
2267        Ok(typed)
2268    }
2269
2270    /// Register constructor schemes for an ADT in the type environment.
2271    /// This makes constructors (e.g. `Some`, `None`, `Ok`, `Err`) available
2272    /// to the type checker as normal values.
2273    pub fn register_adt(&mut self, adt: &AdtDecl) {
2274        self.adts.insert(adt.name.clone(), adt.clone());
2275        for (name, scheme) in adt.constructor_schemes() {
2276            self.register_value_scheme(&name, scheme);
2277        }
2278    }
2279
2280    pub fn adt_from_decl(&mut self, decl: &TypeDecl) -> Result<AdtDecl, TypeError> {
2281        let mut adt = AdtDecl::new(&decl.name, &decl.params, &mut self.supply);
2282        let mut param_map: HashMap<Symbol, TypeVar> = HashMap::new();
2283        for param in &adt.params {
2284            param_map.insert(param.name.clone(), param.var.clone());
2285        }
2286
2287        for variant in &decl.variants {
2288            let mut args = Vec::new();
2289            for arg in &variant.args {
2290                let ty = self.type_from_expr(decl, &param_map, arg)?;
2291                args.push(ty);
2292            }
2293            adt.add_variant(variant.name.clone(), args);
2294        }
2295        Ok(adt)
2296    }
2297
2298    pub fn register_type_decl(&mut self, decl: &TypeDecl) -> Result<(), TypeError> {
2299        if BuiltinTypeId::from_symbol(&decl.name).is_some() {
2300            return Err(TypeError::ReservedTypeName(decl.name.clone()));
2301        }
2302        let adt = self.adt_from_decl(decl)?;
2303        self.register_adt(&adt);
2304        Ok(())
2305    }
2306
2307    fn type_from_expr(
2308        &mut self,
2309        decl: &TypeDecl,
2310        params: &HashMap<Symbol, TypeVar>,
2311        expr: &TypeExpr,
2312    ) -> Result<Type, TypeError> {
2313        let span = *expr.span();
2314        let res = (|| match expr {
2315            TypeExpr::Name(_, name) => {
2316                let name_sym = name.to_dotted_symbol();
2317                if let Some(tv) = params.get(&name_sym) {
2318                    Ok(Type::var(tv.clone()))
2319                } else {
2320                    let name = normalize_type_name(&name_sym);
2321                    if let Some(arity) = self.type_arity(decl, &name) {
2322                        Ok(Type::con(name, arity))
2323                    } else {
2324                        Err(TypeError::UnknownTypeName(name))
2325                    }
2326                }
2327            }
2328            TypeExpr::App(_, fun, arg) => {
2329                let fty = self.type_from_expr(decl, params, fun)?;
2330                let aty = self.type_from_expr(decl, params, arg)?;
2331                Ok(type_app_with_result_syntax(fty, aty))
2332            }
2333            TypeExpr::Fun(_, arg, ret) => {
2334                let arg_ty = self.type_from_expr(decl, params, arg)?;
2335                let ret_ty = self.type_from_expr(decl, params, ret)?;
2336                Ok(Type::fun(arg_ty, ret_ty))
2337            }
2338            TypeExpr::Tuple(_, elems) => {
2339                let mut out = Vec::new();
2340                for elem in elems {
2341                    out.push(self.type_from_expr(decl, params, elem)?);
2342                }
2343                Ok(Type::tuple(out))
2344            }
2345            TypeExpr::Record(_, fields) => {
2346                let mut out = Vec::new();
2347                for (name, ty) in fields {
2348                    out.push((name.clone(), self.type_from_expr(decl, params, ty)?));
2349                }
2350                Ok(Type::record(out))
2351            }
2352        })();
2353        res.map_err(|err| with_span(&span, err))
2354    }
2355
2356    fn type_arity(&self, decl: &TypeDecl, name: &Symbol) -> Option<usize> {
2357        if &decl.name == name {
2358            return Some(decl.params.len());
2359        }
2360        if let Some(adt) = self.adts.get(name) {
2361            return Some(adt.params.len());
2362        }
2363        BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2364    }
2365
2366    fn register_value_scheme(&mut self, name: &Symbol, scheme: Scheme) {
2367        match self.env.lookup(name) {
2368            None => self.env.extend(name.clone(), scheme),
2369            Some(existing) => {
2370                if existing.iter().any(|s| unify(&s.typ, &scheme.typ).is_ok()) {
2371                    return;
2372                }
2373                self.env.extend_overload(name.clone(), scheme);
2374            }
2375        }
2376    }
2377
2378    fn expected_class_param_arities(&self, class: &Symbol) -> Option<Vec<usize>> {
2379        let info = self.class_info.get(class)?;
2380        let mut out = vec![0usize; info.params.len()];
2381        for scheme in info.methods.values() {
2382            for (idx, param) in info.params.iter().enumerate() {
2383                let Some(tv) = scheme.vars.iter().find(|v| v.name.as_ref() == Some(param)) else {
2384                    continue;
2385                };
2386                out[idx] = out[idx].max(max_head_app_arity_for_var(&scheme.typ, tv.id));
2387            }
2388        }
2389        Some(out)
2390    }
2391
2392    fn check_predicate_kind(&self, pred: &Predicate) -> Result<(), TypeError> {
2393        let Some(expected) = self.expected_class_param_arities(&pred.class) else {
2394            // Host-injected classes (via Rust API) won't have `class_info`.
2395            return Ok(());
2396        };
2397
2398        let args: Vec<Type> = if expected.len() == 1 {
2399            vec![pred.typ.clone()]
2400        } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2401            if parts.len() != expected.len() {
2402                return Ok(());
2403            }
2404            parts.clone()
2405        } else {
2406            return Ok(());
2407        };
2408
2409        for (arg, expected_arity) in args.iter().zip(expected.iter().copied()) {
2410            let Some(got) = type_term_remaining_arity(arg) else {
2411                // If we can't determine the arity (e.g. a bare type var), skip:
2412                // call sites may fix it up, and Rex does not currently do full
2413                // kind inference.
2414                continue;
2415            };
2416            if got != expected_arity {
2417                return Err(TypeError::KindMismatch {
2418                    class: pred.class.clone(),
2419                    expected: expected_arity,
2420                    got,
2421                    typ: arg.to_string(),
2422                });
2423            }
2424        }
2425        Ok(())
2426    }
2427
2428    fn check_predicate_kinds(&self, preds: &[Predicate]) -> Result<(), TypeError> {
2429        for pred in preds {
2430            self.check_predicate_kind(pred)?;
2431        }
2432        Ok(())
2433    }
2434}
2435
2436fn type_from_annotation_expr(
2437    adts: &HashMap<Symbol, AdtDecl>,
2438    expr: &TypeExpr,
2439) -> Result<Type, TypeError> {
2440    let span = *expr.span();
2441    let res = (|| match expr {
2442        TypeExpr::Name(_, name) => {
2443            let name = normalize_type_name(&name.to_dotted_symbol());
2444            match annotation_type_arity(adts, &name) {
2445                Some(arity) => Ok(Type::con(name, arity)),
2446                None => Err(TypeError::UnknownTypeName(name)),
2447            }
2448        }
2449        TypeExpr::App(_, fun, arg) => {
2450            let fty = type_from_annotation_expr(adts, fun)?;
2451            let aty = type_from_annotation_expr(adts, arg)?;
2452            Ok(type_app_with_result_syntax(fty, aty))
2453        }
2454        TypeExpr::Fun(_, arg, ret) => {
2455            let arg_ty = type_from_annotation_expr(adts, arg)?;
2456            let ret_ty = type_from_annotation_expr(adts, ret)?;
2457            Ok(Type::fun(arg_ty, ret_ty))
2458        }
2459        TypeExpr::Tuple(_, elems) => {
2460            let mut out = Vec::new();
2461            for elem in elems {
2462                out.push(type_from_annotation_expr(adts, elem)?);
2463            }
2464            Ok(Type::tuple(out))
2465        }
2466        TypeExpr::Record(_, fields) => {
2467            let mut out = Vec::new();
2468            for (name, ty) in fields {
2469                out.push((name.clone(), type_from_annotation_expr(adts, ty)?));
2470            }
2471            Ok(Type::record(out))
2472        }
2473    })();
2474    res.map_err(|err| with_span(&span, err))
2475}
2476
2477fn type_from_annotation_expr_vars(
2478    adts: &HashMap<Symbol, AdtDecl>,
2479    expr: &TypeExpr,
2480    vars: &mut HashMap<Symbol, TypeVar>,
2481    supply: &mut TypeVarSupply,
2482) -> Result<Type, TypeError> {
2483    let span = *expr.span();
2484    let res = (|| match expr {
2485        TypeExpr::Name(_, name) => {
2486            let name = normalize_type_name(&name.to_dotted_symbol());
2487            if let Some(arity) = annotation_type_arity(adts, &name) {
2488                Ok(Type::con(name, arity))
2489            } else if let Some(tv) = vars.get(&name) {
2490                Ok(Type::var(tv.clone()))
2491            } else {
2492                let is_upper = name
2493                    .chars()
2494                    .next()
2495                    .map(|c| c.is_uppercase())
2496                    .unwrap_or(false);
2497                if is_upper {
2498                    return Err(TypeError::UnknownTypeName(name));
2499                }
2500                let tv = supply.fresh(Some(name.clone()));
2501                vars.insert(name.clone(), tv.clone());
2502                Ok(Type::var(tv))
2503            }
2504        }
2505        TypeExpr::App(_, fun, arg) => {
2506            let fty = type_from_annotation_expr_vars(adts, fun, vars, supply)?;
2507            let aty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2508            Ok(type_app_with_result_syntax(fty, aty))
2509        }
2510        TypeExpr::Fun(_, arg, ret) => {
2511            let arg_ty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2512            let ret_ty = type_from_annotation_expr_vars(adts, ret, vars, supply)?;
2513            Ok(Type::fun(arg_ty, ret_ty))
2514        }
2515        TypeExpr::Tuple(_, elems) => {
2516            let mut out = Vec::new();
2517            for elem in elems {
2518                out.push(type_from_annotation_expr_vars(adts, elem, vars, supply)?);
2519            }
2520            Ok(Type::tuple(out))
2521        }
2522        TypeExpr::Record(_, fields) => {
2523            let mut out = Vec::new();
2524            for (name, ty) in fields {
2525                out.push((
2526                    name.clone(),
2527                    type_from_annotation_expr_vars(adts, ty, vars, supply)?,
2528                ));
2529            }
2530            Ok(Type::record(out))
2531        }
2532    })();
2533    res.map_err(|err| with_span(&span, err))
2534}
2535
2536fn annotation_type_arity(adts: &HashMap<Symbol, AdtDecl>, name: &Symbol) -> Option<usize> {
2537    if let Some(adt) = adts.get(name) {
2538        return Some(adt.params.len());
2539    }
2540    BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2541}
2542
2543fn normalize_type_name(name: &Symbol) -> Symbol {
2544    if name.as_ref() == "str" {
2545        BuiltinTypeId::String.as_symbol()
2546    } else {
2547        name.clone()
2548    }
2549}
2550
2551fn type_app_with_result_syntax(fun: Type, arg: Type) -> Type {
2552    if let TypeKind::App(head, ok) = fun.as_ref()
2553        && matches!(
2554            head.as_ref(),
2555            TypeKind::Con(c)
2556                if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
2557        )
2558    {
2559        return Type::app(Type::app(head.clone(), arg), ok.clone());
2560    }
2561    Type::app(fun, arg)
2562}
2563
2564fn predicates_from_constraints(
2565    adts: &HashMap<Symbol, AdtDecl>,
2566    constraints: &[TypeConstraint],
2567    vars: &mut HashMap<Symbol, TypeVar>,
2568    supply: &mut TypeVarSupply,
2569) -> Result<Vec<Predicate>, TypeError> {
2570    let mut out = Vec::with_capacity(constraints.len());
2571    for constraint in constraints {
2572        let ty = type_from_annotation_expr_vars(adts, &constraint.typ, vars, supply)?;
2573        out.push(Predicate::new(constraint.class.as_ref(), ty));
2574    }
2575    Ok(out)
2576}
2577
/// A constructor name that [`collect_adts_in_types`] observed with more than
/// one distinct definition.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AdtConflict {
    /// The constructor name shared by every entry in `definitions`.
    pub name: Symbol,
    /// Each distinct constructor head seen under `name`, in first-seen order.
    pub definitions: Vec<Type>,
}
2583
/// Error returned by [`collect_adts_in_types`] when one or more constructor
/// names are used with incompatible definitions.
///
/// The `Display` message embeds the `Debug` rendering of `conflicts`.
#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)]
#[error("conflicting ADT definitions: {conflicts:?}")]
pub struct CollectAdtsError {
    /// Every conflicting constructor name; `collect_adts_in_types` builds this
    /// from a `BTreeMap`, so entries are ordered by name.
    pub conflicts: Vec<AdtConflict>,
}
2589
2590/// Collect all user-defined ADT constructors referenced by the provided types.
2591///
2592/// This walks each type recursively (including nested occurrences), returns a
2593/// deduplicated list of constructor heads, and rejects ambiguous constructor
2594/// names that appear with incompatible definitions.
2595///
2596/// The returned `Type`s are constructor heads (for example `Foo`), suitable
2597/// for passing to embedder utilities that derive `AdtDecl`s from type
2598/// constructors.
2599///
2600/// # Examples
2601///
2602/// ```rust,ignore
2603/// use rex_ts::{collect_adts_in_types, BuiltinTypeId, Type};
2604///
2605/// let types = vec![
2606///     Type::app(Type::user_con("Foo", 1), Type::builtin(BuiltinTypeId::I32)),
2607///     Type::fun(Type::user_con("Bar", 0), Type::user_con("Foo", 1)),
2608/// ];
2609///
2610/// let adts = collect_adts_in_types(types).unwrap();
2611/// assert_eq!(adts, vec![Type::user_con("Foo", 1), Type::user_con("Bar", 0)]);
2612/// ```
2613///
2614/// ```rust,ignore
2615/// use rex_ts::{collect_adts_in_types, Type};
2616///
2617/// let err = collect_adts_in_types(vec![
2618///     Type::user_con("Thing", 1),
2619///     Type::user_con("Thing", 2),
2620/// ])
2621/// .unwrap_err();
2622///
2623/// assert_eq!(err.conflicts.len(), 1);
2624/// assert_eq!(err.conflicts[0].name.as_ref(), "Thing");
2625/// ```
2626pub fn collect_adts_in_types(types: Vec<Type>) -> Result<Vec<Type>, CollectAdtsError> {
2627    fn visit(
2628        typ: &Type,
2629        out: &mut Vec<Type>,
2630        seen: &mut HashSet<Type>,
2631        defs_by_name: &mut BTreeMap<Symbol, Vec<Type>>,
2632    ) {
2633        match typ.as_ref() {
2634            TypeKind::Var(_) => {}
2635            TypeKind::Con(tc) => {
2636                // Builtins are not embeddable ADT declarations.
2637                if tc.builtin_id.is_none() {
2638                    let adt = Type::new(TypeKind::Con(tc.clone()));
2639                    if seen.insert(adt.clone()) {
2640                        out.push(adt.clone());
2641                    }
2642                    let defs = defs_by_name.entry(tc.name.clone()).or_default();
2643                    if !defs.contains(&adt) {
2644                        defs.push(adt);
2645                    }
2646                }
2647            }
2648            TypeKind::App(fun, arg) => {
2649                visit(fun, out, seen, defs_by_name);
2650                visit(arg, out, seen, defs_by_name);
2651            }
2652            TypeKind::Fun(arg, ret) => {
2653                visit(arg, out, seen, defs_by_name);
2654                visit(ret, out, seen, defs_by_name);
2655            }
2656            TypeKind::Tuple(elems) => {
2657                for elem in elems {
2658                    visit(elem, out, seen, defs_by_name);
2659                }
2660            }
2661            TypeKind::Record(fields) => {
2662                for (_name, field_ty) in fields {
2663                    visit(field_ty, out, seen, defs_by_name);
2664                }
2665            }
2666        }
2667    }
2668
2669    let mut out = Vec::new();
2670    let mut seen = HashSet::new();
2671    let mut defs_by_name: BTreeMap<Symbol, Vec<Type>> = BTreeMap::new();
2672    for typ in &types {
2673        visit(typ, &mut out, &mut seen, &mut defs_by_name);
2674    }
2675
2676    let conflicts: Vec<AdtConflict> = defs_by_name
2677        .into_iter()
2678        .filter_map(|(name, definitions)| {
2679            (definitions.len() > 1).then_some(AdtConflict { name, definitions })
2680        })
2681        .collect();
2682    if !conflicts.is_empty() {
2683        return Err(CollectAdtsError { conflicts });
2684    }
2685
2686    Ok(out)
2687}