Skip to main content

rexlang_typesystem/
typesystem.rs

1//! Core type system implementation for Rex.
2
3use std::collections::{BTreeMap, HashMap, HashSet};
4use std::fmt::{self, Display, Formatter};
5use std::sync::Arc;
6
7use chrono::{DateTime, Utc};
8use rexlang_ast::expr::{
9    ClassDecl, ClassMethodSig, Decl, DeclareFnDecl, Expr, FnDecl, InstanceDecl, InstanceMethodImpl,
10    Pattern, Scope, Symbol, TypeConstraint, TypeDecl, TypeExpr, intern, sym,
11};
12use rexlang_lexer::span::Span;
13use rexlang_util::{GasMeter, OutOfGas};
14use rpds::HashTrieMapSync;
15use uuid::Uuid;
16
17use crate::prelude;
18
19#[path = "inference.rs"]
20pub mod inference;
21
22pub use inference::{infer, infer_typed, infer_typed_with_gas, infer_with_gas};
23
/// Identifier of a type variable.
///
/// A plain integer so it can be used directly as an index into dense tables
/// (the unifier keeps its bindings in a `Vec` indexed by this id).
pub type TypeVarId = usize;
25
26#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
27pub enum BuiltinTypeId {
28    U8,
29    U16,
30    U32,
31    U64,
32    I8,
33    I16,
34    I32,
35    I64,
36    F32,
37    F64,
38    Bool,
39    String,
40    Uuid,
41    DateTime,
42    List,
43    Array,
44    Dict,
45    Option,
46    Promise,
47    Result,
48}
49
50impl BuiltinTypeId {
51    pub fn as_symbol(self) -> Symbol {
52        sym(self.as_str())
53    }
54
55    pub fn as_str(self) -> &'static str {
56        match self {
57            Self::U8 => "u8",
58            Self::U16 => "u16",
59            Self::U32 => "u32",
60            Self::U64 => "u64",
61            Self::I8 => "i8",
62            Self::I16 => "i16",
63            Self::I32 => "i32",
64            Self::I64 => "i64",
65            Self::F32 => "f32",
66            Self::F64 => "f64",
67            Self::Bool => "bool",
68            Self::String => "string",
69            Self::Uuid => "uuid",
70            Self::DateTime => "datetime",
71            Self::List => "List",
72            Self::Array => "Array",
73            Self::Dict => "Dict",
74            Self::Option => "Option",
75            Self::Promise => "Promise",
76            Self::Result => "Result",
77        }
78    }
79
80    pub fn arity(self) -> usize {
81        match self {
82            Self::List | Self::Array | Self::Dict | Self::Option | Self::Promise => 1,
83            Self::Result => 2,
84            _ => 0,
85        }
86    }
87
88    pub fn from_symbol(name: &Symbol) -> Option<Self> {
89        Self::from_name(name.as_ref())
90    }
91
92    pub fn from_name(name: &str) -> Option<Self> {
93        match name {
94            "u8" => Some(Self::U8),
95            "u16" => Some(Self::U16),
96            "u32" => Some(Self::U32),
97            "u64" => Some(Self::U64),
98            "i8" => Some(Self::I8),
99            "i16" => Some(Self::I16),
100            "i32" => Some(Self::I32),
101            "i64" => Some(Self::I64),
102            "f32" => Some(Self::F32),
103            "f64" => Some(Self::F64),
104            "bool" => Some(Self::Bool),
105            "string" => Some(Self::String),
106            "uuid" => Some(Self::Uuid),
107            "datetime" => Some(Self::DateTime),
108            "List" => Some(Self::List),
109            "Array" => Some(Self::Array),
110            "Dict" => Some(Self::Dict),
111            "Option" => Some(Self::Option),
112            "Promise" => Some(Self::Promise),
113            "Result" => Some(Self::Result),
114            _ => None,
115        }
116    }
117}
118
119#[derive(Clone, Debug, Eq, Hash, PartialEq)]
120pub struct TypeVar {
121    pub id: TypeVarId,
122    pub name: Option<Symbol>,
123}
124
125impl TypeVar {
126    pub fn new(id: TypeVarId, name: impl Into<Option<Symbol>>) -> Self {
127        Self {
128            id,
129            name: name.into(),
130        }
131    }
132}
133
134#[derive(Clone, Debug, Eq, Hash, PartialEq)]
135pub struct TypeConst {
136    pub name: Symbol,
137    pub arity: usize,
138    pub builtin_id: Option<BuiltinTypeId>,
139}
140
141#[derive(Clone, Debug, PartialEq, Eq, Hash)]
142pub struct Type(Arc<TypeKind>);
143
144#[derive(Clone, Debug, PartialEq, Eq, Hash)]
145pub enum TypeKind {
146    Var(TypeVar),
147    Con(TypeConst),
148    App(Type, Type),
149    Fun(Type, Type),
150    Tuple(Vec<Type>),
151    /// Record type `{a: T, b: U}`.
152    ///
153    /// Invariant: fields are sorted by name. This makes record equality and
154    /// unification a cheap zip over two vectors, and it makes printing stable.
155    Record(Vec<(Symbol, Type)>),
156}
157
158impl Type {
159    pub fn new(kind: TypeKind) -> Self {
160        Type(Arc::new(kind))
161    }
162
163    pub fn con(name: impl AsRef<str>, arity: usize) -> Self {
164        if let Some(id) = BuiltinTypeId::from_name(name.as_ref())
165            && id.arity() == arity
166        {
167            return Self::builtin(id);
168        }
169        Self::user_con(name, arity)
170    }
171
172    pub fn user_con(name: impl AsRef<str>, arity: usize) -> Self {
173        Type::new(TypeKind::Con(TypeConst {
174            name: intern(name.as_ref()),
175            arity,
176            builtin_id: None,
177        }))
178    }
179
180    pub fn builtin(id: BuiltinTypeId) -> Self {
181        Type::new(TypeKind::Con(TypeConst {
182            name: id.as_symbol(),
183            arity: id.arity(),
184            builtin_id: Some(id),
185        }))
186    }
187
188    pub fn var(tv: TypeVar) -> Self {
189        Type::new(TypeKind::Var(tv))
190    }
191
192    pub fn fun(a: Type, b: Type) -> Self {
193        Type::new(TypeKind::Fun(a, b))
194    }
195
196    pub fn app(f: Type, arg: Type) -> Self {
197        Type::new(TypeKind::App(f, arg))
198    }
199
200    pub fn tuple(elems: Vec<Type>) -> Self {
201        Type::new(TypeKind::Tuple(elems))
202    }
203
204    pub fn record(mut fields: Vec<(Symbol, Type)>) -> Self {
205        // Canonicalize records so downstream code can rely on “same shape means
206        // same ordering”. (This is a correctness invariant, not a nicety.)
207        fields.sort_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
208        Type::new(TypeKind::Record(fields))
209    }
210
211    pub fn list(elem: Type) -> Type {
212        Type::app(Type::builtin(BuiltinTypeId::List), elem)
213    }
214
215    pub fn array(elem: Type) -> Type {
216        Type::app(Type::builtin(BuiltinTypeId::Array), elem)
217    }
218
219    pub fn dict(elem: Type) -> Type {
220        Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
221    }
222
223    pub fn option(elem: Type) -> Type {
224        Type::app(Type::builtin(BuiltinTypeId::Option), elem)
225    }
226
227    pub fn promise(elem: Type) -> Type {
228        Type::app(Type::builtin(BuiltinTypeId::Promise), elem)
229    }
230
231    pub fn result(ok: Type, err: Type) -> Type {
232        Type::app(Type::app(Type::builtin(BuiltinTypeId::Result), err), ok)
233    }
234
235    fn apply_with_change(&self, s: &Subst) -> (Type, bool) {
236        match self.as_ref() {
237            TypeKind::Var(tv) => match s.get(&tv.id) {
238                Some(ty) => (ty.clone(), true),
239                None => (self.clone(), false),
240            },
241            TypeKind::Con(_) => (self.clone(), false),
242            TypeKind::App(l, r) => {
243                let (l_new, l_changed) = l.apply_with_change(s);
244                let (r_new, r_changed) = r.apply_with_change(s);
245                if l_changed || r_changed {
246                    (Type::app(l_new, r_new), true)
247                } else {
248                    (self.clone(), false)
249                }
250            }
251            TypeKind::Fun(_, _) => {
252                // Avoid recursive descent on long function chains like
253                // `a1 -> a2 -> ... -> an -> r`.
254                let mut args = Vec::new();
255                let mut changed = false;
256                let mut cur: &Type = self;
257                while let TypeKind::Fun(a, b) = cur.as_ref() {
258                    let (a_new, a_changed) = a.apply_with_change(s);
259                    changed |= a_changed;
260                    args.push(a_new);
261                    cur = b;
262                }
263                let (ret_new, ret_changed) = cur.apply_with_change(s);
264                changed |= ret_changed;
265                if !changed {
266                    return (self.clone(), false);
267                }
268                let mut out = ret_new;
269                for a_new in args.into_iter().rev() {
270                    out = Type::fun(a_new, out);
271                }
272                (out, true)
273            }
274            TypeKind::Tuple(ts) => {
275                let mut changed = false;
276                let mut out = Vec::with_capacity(ts.len());
277                for t in ts {
278                    let (t_new, t_changed) = t.apply_with_change(s);
279                    changed |= t_changed;
280                    out.push(t_new);
281                }
282                if changed {
283                    (Type::new(TypeKind::Tuple(out)), true)
284                } else {
285                    (self.clone(), false)
286                }
287            }
288            TypeKind::Record(fields) => {
289                let mut changed = false;
290                let mut out = Vec::with_capacity(fields.len());
291                for (k, v) in fields {
292                    let (v_new, v_changed) = v.apply_with_change(s);
293                    changed |= v_changed;
294                    out.push((k.clone(), v_new));
295                }
296                if changed {
297                    (Type::new(TypeKind::Record(out)), true)
298                } else {
299                    (self.clone(), false)
300                }
301            }
302        }
303    }
304}
305
306impl AsRef<TypeKind> for Type {
307    fn as_ref(&self) -> &TypeKind {
308        self.0.as_ref()
309    }
310}
311
312impl std::ops::Deref for Type {
313    type Target = TypeKind;
314
315    fn deref(&self) -> &Self::Target {
316        &self.0
317    }
318}
319
320impl Display for Type {
321    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
322        match self.as_ref() {
323            TypeKind::Var(tv) => match &tv.name {
324                Some(name) => write!(f, "'{}", name),
325                None => write!(f, "t{}", tv.id),
326            },
327            TypeKind::Con(c) => write!(f, "{}", c.name),
328            TypeKind::App(l, r) => {
329                // Internally `Result` is represented as `Result err ok` so it can be partially
330                // applied as `Result err` for HKTs (Functor/Monad/etc).
331                //
332                // User-facing syntax is `Result ok err` (Rust-style), so render the fully
333                // applied form with swapped arguments.
334                if let TypeKind::App(head, err) = l.as_ref()
335                    && matches!(
336                        head.as_ref(),
337                        TypeKind::Con(c)
338                            if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
339                    )
340                {
341                    return write!(f, "(Result {} {})", r, err);
342                }
343                write!(f, "({} {})", l, r)
344            }
345            TypeKind::Fun(a, b) => write!(f, "({} -> {})", a, b),
346            TypeKind::Tuple(elems) => {
347                write!(f, "(")?;
348                for (i, t) in elems.iter().enumerate() {
349                    write!(f, "{}", t)?;
350                    if i + 1 < elems.len() {
351                        write!(f, ", ")?;
352                    }
353                }
354                write!(f, ")")
355            }
356            TypeKind::Record(fields) => {
357                write!(f, "{{")?;
358                for (i, (name, ty)) in fields.iter().enumerate() {
359                    write!(f, "{}: {}", name, ty)?;
360                    if i + 1 < fields.len() {
361                        write!(f, ", ")?;
362                    }
363                }
364                write!(f, "}}")
365            }
366        }
367    }
368}
369
370#[derive(Clone, Debug, PartialEq, Eq, Hash)]
371pub struct Predicate {
372    pub class: Symbol,
373    pub typ: Type,
374}
375
376impl Predicate {
377    pub fn new(class: impl AsRef<str>, typ: Type) -> Self {
378        Self {
379            class: intern(class.as_ref()),
380            typ,
381        }
382    }
383}
384
385#[derive(Clone, Debug, PartialEq)]
386pub struct Scheme {
387    pub vars: Vec<TypeVar>,
388    pub preds: Vec<Predicate>,
389    pub typ: Type,
390}
391
392impl Scheme {
393    pub fn new(vars: Vec<TypeVar>, preds: Vec<Predicate>, typ: Type) -> Self {
394        Self { vars, preds, typ }
395    }
396}
397
398pub type Subst = HashTrieMapSync<TypeVarId, Type>;
399
400pub trait Types: Sized {
401    fn apply(&self, s: &Subst) -> Self;
402    fn ftv(&self) -> HashSet<TypeVarId>;
403}
404
405impl Types for Type {
406    fn apply(&self, s: &Subst) -> Self {
407        self.apply_with_change(s).0
408    }
409
410    fn ftv(&self) -> HashSet<TypeVarId> {
411        let mut out = HashSet::new();
412        let mut stack: Vec<&Type> = vec![self];
413        while let Some(t) = stack.pop() {
414            match t.as_ref() {
415                TypeKind::Var(tv) => {
416                    out.insert(tv.id);
417                }
418                TypeKind::Con(_) => {}
419                TypeKind::App(l, r) => {
420                    stack.push(l);
421                    stack.push(r);
422                }
423                TypeKind::Fun(a, b) => {
424                    stack.push(a);
425                    stack.push(b);
426                }
427                TypeKind::Tuple(ts) => {
428                    for t in ts {
429                        stack.push(t);
430                    }
431                }
432                TypeKind::Record(fields) => {
433                    for (_, ty) in fields {
434                        stack.push(ty);
435                    }
436                }
437            }
438        }
439        out
440    }
441}
442
443impl Types for Predicate {
444    fn apply(&self, s: &Subst) -> Self {
445        Predicate {
446            class: self.class.clone(),
447            typ: self.typ.apply(s),
448        }
449    }
450
451    fn ftv(&self) -> HashSet<TypeVarId> {
452        self.typ.ftv()
453    }
454}
455
456impl Types for Scheme {
457    fn apply(&self, s: &Subst) -> Self {
458        let mut s_pruned = Subst::new_sync();
459        for (k, v) in s.iter() {
460            if !self.vars.iter().any(|var| var.id == *k) {
461                s_pruned = s_pruned.insert(*k, v.clone());
462            }
463        }
464        Scheme::new(
465            self.vars.clone(),
466            self.preds.iter().map(|p| p.apply(&s_pruned)).collect(),
467            self.typ.apply(&s_pruned),
468        )
469    }
470
471    fn ftv(&self) -> HashSet<TypeVarId> {
472        let mut ftv = self.typ.ftv();
473        for p in &self.preds {
474            ftv.extend(p.ftv());
475        }
476        for v in &self.vars {
477            ftv.remove(&v.id);
478        }
479        ftv
480    }
481}
482
483impl<T: Types> Types for Vec<T> {
484    fn apply(&self, s: &Subst) -> Self {
485        self.iter().map(|t| t.apply(s)).collect()
486    }
487
488    fn ftv(&self) -> HashSet<TypeVarId> {
489        self.iter().flat_map(Types::ftv).collect()
490    }
491}
492
493#[derive(Clone, Debug, PartialEq)]
494pub struct TypedExpr {
495    pub typ: Type,
496    pub kind: TypedExprKind,
497}
498
499impl TypedExpr {
500    pub fn new(typ: Type, kind: TypedExprKind) -> Self {
501        Self { typ, kind }
502    }
503
504    pub fn apply(&self, s: &Subst) -> Self {
505        match &self.kind {
506            TypedExprKind::Lam { .. } => {
507                let mut params: Vec<(Symbol, Type)> = Vec::new();
508                let mut cur = self;
509                while let TypedExprKind::Lam { param, body } = &cur.kind {
510                    params.push((param.clone(), cur.typ.apply(s)));
511                    cur = body.as_ref();
512                }
513                let mut out = cur.apply(s);
514                for (param, typ) in params.into_iter().rev() {
515                    out = TypedExpr {
516                        typ,
517                        kind: TypedExprKind::Lam {
518                            param,
519                            body: Box::new(out),
520                        },
521                    };
522                }
523                return out;
524            }
525            TypedExprKind::App(..) => {
526                let mut apps: Vec<(Type, &TypedExpr)> = Vec::new();
527                let mut cur = self;
528                while let TypedExprKind::App(f, x) = &cur.kind {
529                    apps.push((cur.typ.apply(s), x.as_ref()));
530                    cur = f.as_ref();
531                }
532                let mut out = cur.apply(s);
533                for (typ, arg) in apps.into_iter().rev() {
534                    out = TypedExpr {
535                        typ,
536                        kind: TypedExprKind::App(Box::new(out), Box::new(arg.apply(s))),
537                    };
538                }
539                return out;
540            }
541            _ => {}
542        }
543
544        let typ = self.typ.apply(s);
545        let kind = match &self.kind {
546            TypedExprKind::Bool(v) => TypedExprKind::Bool(*v),
547            TypedExprKind::Uint(v) => TypedExprKind::Uint(*v),
548            TypedExprKind::Int(v) => TypedExprKind::Int(*v),
549            TypedExprKind::Float(v) => TypedExprKind::Float(*v),
550            TypedExprKind::String(v) => TypedExprKind::String(v.clone()),
551            TypedExprKind::Uuid(v) => TypedExprKind::Uuid(*v),
552            TypedExprKind::DateTime(v) => TypedExprKind::DateTime(*v),
553            TypedExprKind::Hole => TypedExprKind::Hole,
554            TypedExprKind::Tuple(elems) => {
555                TypedExprKind::Tuple(elems.iter().map(|e| e.apply(s)).collect())
556            }
557            TypedExprKind::List(elems) => {
558                TypedExprKind::List(elems.iter().map(|e| e.apply(s)).collect())
559            }
560            TypedExprKind::Dict(kvs) => {
561                let mut out = BTreeMap::new();
562                for (k, v) in kvs {
563                    out.insert(k.clone(), v.apply(s));
564                }
565                TypedExprKind::Dict(out)
566            }
567            TypedExprKind::RecordUpdate { base, updates } => {
568                let mut out = BTreeMap::new();
569                for (k, v) in updates {
570                    out.insert(k.clone(), v.apply(s));
571                }
572                TypedExprKind::RecordUpdate {
573                    base: Box::new(base.apply(s)),
574                    updates: out,
575                }
576            }
577            TypedExprKind::Var { name, overloads } => TypedExprKind::Var {
578                name: name.clone(),
579                overloads: overloads.iter().map(|t| t.apply(s)).collect(),
580            },
581            TypedExprKind::App(f, x) => {
582                TypedExprKind::App(Box::new(f.apply(s)), Box::new(x.apply(s)))
583            }
584            TypedExprKind::Project { expr, field } => TypedExprKind::Project {
585                expr: Box::new(expr.apply(s)),
586                field: field.clone(),
587            },
588            TypedExprKind::Lam { param, body } => TypedExprKind::Lam {
589                param: param.clone(),
590                body: Box::new(body.apply(s)),
591            },
592            TypedExprKind::Let { name, def, body } => TypedExprKind::Let {
593                name: name.clone(),
594                def: Box::new(def.apply(s)),
595                body: Box::new(body.apply(s)),
596            },
597            TypedExprKind::LetRec { bindings, body } => TypedExprKind::LetRec {
598                bindings: bindings
599                    .iter()
600                    .map(|(name, def)| (name.clone(), def.apply(s)))
601                    .collect(),
602                body: Box::new(body.apply(s)),
603            },
604            TypedExprKind::Ite {
605                cond,
606                then_expr,
607                else_expr,
608            } => TypedExprKind::Ite {
609                cond: Box::new(cond.apply(s)),
610                then_expr: Box::new(then_expr.apply(s)),
611                else_expr: Box::new(else_expr.apply(s)),
612            },
613            TypedExprKind::Match { scrutinee, arms } => TypedExprKind::Match {
614                scrutinee: Box::new(scrutinee.apply(s)),
615                arms: arms.iter().map(|(p, e)| (p.clone(), e.apply(s))).collect(),
616            },
617        };
618        TypedExpr { typ, kind }
619    }
620}
621
622#[derive(Clone, Debug, PartialEq)]
623pub enum TypedExprKind {
624    Bool(bool),
625    Uint(u64),
626    Int(i64),
627    Float(f64),
628    String(String),
629    Uuid(Uuid),
630    DateTime(DateTime<Utc>),
631    Hole,
632    Tuple(Vec<TypedExpr>),
633    List(Vec<TypedExpr>),
634    Dict(BTreeMap<Symbol, TypedExpr>),
635    RecordUpdate {
636        base: Box<TypedExpr>,
637        updates: BTreeMap<Symbol, TypedExpr>,
638    },
639    Var {
640        name: Symbol,
641        overloads: Vec<Type>,
642    },
643    App(Box<TypedExpr>, Box<TypedExpr>),
644    Project {
645        expr: Box<TypedExpr>,
646        field: Symbol,
647    },
648    Lam {
649        param: Symbol,
650        body: Box<TypedExpr>,
651    },
652    Let {
653        name: Symbol,
654        def: Box<TypedExpr>,
655        body: Box<TypedExpr>,
656    },
657    LetRec {
658        bindings: Vec<(Symbol, TypedExpr)>,
659        body: Box<TypedExpr>,
660    },
661    Ite {
662        cond: Box<TypedExpr>,
663        then_expr: Box<TypedExpr>,
664        else_expr: Box<TypedExpr>,
665    },
666    Match {
667        scrutinee: Box<TypedExpr>,
668        arms: Vec<(Pattern, TypedExpr)>,
669    },
670}
671
672/// Compose substitutions `a` after `b`.
673///
674/// If `t.apply(&b)` is “apply `b` first”, then:
675/// `t.apply(&compose_subst(a, b)) == t.apply(&b).apply(&a)`.
676pub fn compose_subst(a: Subst, b: Subst) -> Subst {
677    if subst_is_empty(&a) {
678        return b;
679    }
680    if subst_is_empty(&b) {
681        return a;
682    }
683    let mut res = Subst::new_sync();
684    for (k, v) in b.iter() {
685        res = res.insert(*k, v.apply(&a));
686    }
687    for (k, v) in a.iter() {
688        res = res.insert(*k, v.clone());
689    }
690    res
691}
692
693fn subst_is_empty(s: &Subst) -> bool {
694    s.iter().next().is_none()
695}
696
697#[derive(Debug, thiserror::Error, PartialEq, Eq)]
698pub enum TypeError {
699    #[error("types do not unify: {0} vs {1}")]
700    Unification(String, String),
701    #[error("occurs check failed for {0} in {1}")]
702    Occurs(TypeVarId, String),
703    #[error("unknown class {0}")]
704    UnknownClass(Symbol),
705    #[error("no instance for {0} {1}")]
706    NoInstance(Symbol, String),
707    #[error("unknown type {0}")]
708    UnknownTypeName(Symbol),
709    #[error("cannot redefine reserved builtin type `{0}`")]
710    ReservedTypeName(Symbol),
711    #[error("duplicate value definition `{0}`")]
712    DuplicateValue(Symbol),
713    #[error("duplicate class definition `{0}`")]
714    DuplicateClass(Symbol),
715    #[error("class `{class}` must have at least one type parameter (got {got})")]
716    InvalidClassArity { class: Symbol, got: usize },
717    #[error("duplicate class method `{0}`")]
718    DuplicateClassMethod(Symbol),
719    #[error("unknown method `{method}` in instance of class `{class}`")]
720    UnknownInstanceMethod { class: Symbol, method: Symbol },
721    #[error("missing implementation of `{method}` for instance of class `{class}`")]
722    MissingInstanceMethod { class: Symbol, method: Symbol },
723    #[error(
724        "instance method `{method}` requires constraint {class} {typ}, but it is not in the instance context"
725    )]
726    MissingInstanceConstraint {
727        method: Symbol,
728        class: Symbol,
729        typ: String,
730    },
731    #[error("unbound variable {0}")]
732    UnknownVar(Symbol),
733    #[error("ambiguous overload for {0}")]
734    AmbiguousOverload(Symbol),
735    #[error("ambiguous type variable(s) {vars:?} in constraints: {constraints}")]
736    AmbiguousTypeVars {
737        vars: Vec<TypeVarId>,
738        constraints: String,
739    },
740    #[error(
741        "kind mismatch for class `{class}`: expected {expected} type argument(s) remaining, got {got} for {typ}"
742    )]
743    KindMismatch {
744        class: Symbol,
745        expected: usize,
746        got: usize,
747        typ: String,
748    },
749    #[error("missing type class constraint(s): {constraints}")]
750    MissingConstraints { constraints: String },
751    #[error("unsupported expression {0}")]
752    UnsupportedExpr(&'static str),
753    #[error("unknown field `{field}` on {typ}")]
754    UnknownField { field: Symbol, typ: String },
755    #[error("field `{field}` is not definitely available on {typ}")]
756    FieldNotKnown { field: Symbol, typ: String },
757    #[error("non-exhaustive match for {typ}: missing {missing:?}")]
758    NonExhaustiveMatch { typ: String, missing: Vec<Symbol> },
759    #[error("at {span}: {error}")]
760    Spanned { span: Span, error: Box<TypeError> },
761    #[error("internal error: {0}")]
762    Internal(String),
763    #[error("{0}")]
764    OutOfGas(#[from] OutOfGas),
765}
766
767fn with_span(span: &Span, err: TypeError) -> TypeError {
768    match err {
769        TypeError::Spanned { .. } => err,
770        other => TypeError::Spanned {
771            span: *span,
772            error: Box::new(other),
773        },
774    }
775}
776
777fn format_constraints_referencing_vars(preds: &[Predicate], vars: &[TypeVarId]) -> String {
778    if vars.is_empty() {
779        return String::new();
780    }
781    let var_set: HashSet<TypeVarId> = vars.iter().copied().collect();
782    let mut parts = Vec::new();
783    for pred in preds {
784        let ftv = pred.ftv();
785        if ftv.iter().any(|v| var_set.contains(v)) {
786            parts.push(format!("{} {}", pred.class, pred.typ));
787        }
788    }
789    if parts.is_empty() {
790        // Fallback: show all constraints if the filtering logic misses something.
791        for pred in preds {
792            parts.push(format!("{} {}", pred.class, pred.typ));
793        }
794    }
795    parts.join(", ")
796}
797
798fn reject_ambiguous_scheme(scheme: &Scheme) -> Result<(), TypeError> {
799    // Only reject *quantified* ambiguous variables. Variables free in the
800    // environment are allowed to appear only in predicates, since they can be
801    // determined by outer context.
802    let quantified: HashSet<TypeVarId> = scheme.vars.iter().map(|v| v.id).collect();
803    if quantified.is_empty() {
804        return Ok(());
805    }
806
807    let typ_ftv = scheme.typ.ftv();
808    let mut vars = HashSet::new();
809    for pred in &scheme.preds {
810        let TypeKind::Var(tv) = pred.typ.as_ref() else {
811            continue;
812        };
813        if quantified.contains(&tv.id) && !typ_ftv.contains(&tv.id) {
814            vars.insert(tv.id);
815        }
816    }
817
818    if vars.is_empty() {
819        return Ok(());
820    }
821    let mut vars: Vec<TypeVarId> = vars.into_iter().collect();
822    vars.sort_unstable();
823    let constraints = format_constraints_referencing_vars(&scheme.preds, &vars);
824    Err(TypeError::AmbiguousTypeVars { vars, constraints })
825}
826
827fn scheme_compatible(existing: &Scheme, declared: &Scheme) -> bool {
828    let s = match unify(&existing.typ, &declared.typ) {
829        Ok(s) => s,
830        Err(_) => return false,
831    };
832
833    let existing_preds = existing.preds.apply(&s);
834    let declared_preds = declared.preds.apply(&s);
835
836    let mut lhs: Vec<(Symbol, String)> = existing_preds
837        .iter()
838        .map(|p| (p.class.clone(), p.typ.to_string()))
839        .collect();
840    let mut rhs: Vec<(Symbol, String)> = declared_preds
841        .iter()
842        .map(|p| (p.class.clone(), p.typ.to_string()))
843        .collect();
844    lhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
845    rhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
846    lhs == rhs
847}
848
849#[derive(Debug)]
850struct Unifier<'g> {
851    // `subs[id] = Some(t)` means type variable `id` has been bound to `t`.
852    //
853    // This is intentionally a dense `Vec` rather than a `HashMap`: inference
854    // generates `TypeVarId`s from a monotonic counter, so the common case is
855    // “small id space, lots of lookups”. This makes the cost model obvious:
856    // you pay O(max_id) space, and you get O(1) binds/queries.
857    subs: Vec<Option<Type>>,
858    gas: Option<&'g mut GasMeter>,
859    max_infer_depth: Option<usize>,
860    infer_depth: usize,
861}
862
/// Resource limits applied during type inference.
#[derive(Clone, Copy, Debug)]
pub struct TypeSystemLimits {
    /// Maximum inference nesting depth; `None` disables the check.
    pub max_infer_depth: Option<usize>,
}

impl TypeSystemLimits {
    /// Depth cap used by [`TypeSystemLimits::safe_defaults`].
    const DEFAULT_MAX_INFER_DEPTH: usize = 4096;

    /// Limits that never trigger.
    pub fn unlimited() -> Self {
        TypeSystemLimits {
            max_infer_depth: None,
        }
    }

    /// Conservative limits: inference depth capped at 4096.
    pub fn safe_defaults() -> Self {
        TypeSystemLimits {
            max_infer_depth: Some(Self::DEFAULT_MAX_INFER_DEPTH),
        }
    }
}

impl Default for TypeSystemLimits {
    /// Same as [`TypeSystemLimits::safe_defaults`].
    fn default() -> Self {
        Self::safe_defaults()
    }
}
887
888fn superclass_closure(class_env: &ClassEnv, given: &[Predicate]) -> Vec<Predicate> {
889    let mut closure: Vec<Predicate> = given.to_vec();
890    let mut i = 0;
891    while i < closure.len() {
892        let p = closure[i].clone();
893        for sup in class_env.supers_of(&p.class) {
894            closure.push(Predicate::new(sup, p.typ.clone()));
895        }
896        i += 1;
897    }
898    closure
899}
900
901fn check_non_ground_predicates_declared(
902    class_env: &ClassEnv,
903    declared: &[Predicate],
904    inferred: &[Predicate],
905) -> Result<(), TypeError> {
906    // Compare by a stable, user-facing rendering (`Default a`, `Foldable t`, ...),
907    // rather than `TypeVarId`, so signature variables that only appear in
908    // predicates (and thus aren't related by unification) still match up.
909    let closure = superclass_closure(class_env, declared);
910    let closure_keys: HashSet<String> = closure
911        .iter()
912        .map(|p| format!("{} {}", p.class, p.typ))
913        .collect();
914    let mut missing = Vec::new();
915    for pred in inferred {
916        if pred.typ.ftv().is_empty() {
917            continue;
918        }
919        let key = format!("{} {}", pred.class, pred.typ);
920        if !closure_keys.contains(&key) {
921            missing.push(key);
922        }
923    }
924
925    missing.sort();
926    missing.dedup();
927    if missing.is_empty() {
928        return Ok(());
929    }
930    Err(TypeError::MissingConstraints {
931        constraints: missing.join(", "),
932    })
933}
934
935fn type_term_remaining_arity(ty: &Type) -> Option<usize> {
936    match ty.as_ref() {
937        TypeKind::Var(_) => None,
938        TypeKind::Con(tc) => Some(tc.arity),
939        TypeKind::App(l, _) => {
940            let a = type_term_remaining_arity(l)?;
941            Some(a.saturating_sub(1))
942        }
943        TypeKind::Fun(..) | TypeKind::Tuple(..) | TypeKind::Record(..) => Some(0),
944    }
945}
946
947fn max_head_app_arity_for_var(ty: &Type, var_id: TypeVarId) -> usize {
948    let mut max_arity = 0usize;
949    let mut stack: Vec<&Type> = vec![ty];
950    while let Some(t) = stack.pop() {
951        match t.as_ref() {
952            TypeKind::Var(_) | TypeKind::Con(_) => {}
953            TypeKind::App(l, r) => {
954                // Record the full application depth at this node.
955                let mut head = t;
956                let mut args = 0usize;
957                while let TypeKind::App(left, _) = head.as_ref() {
958                    args += 1;
959                    head = left;
960                }
961                if let TypeKind::Var(tv) = head.as_ref()
962                    && tv.id == var_id
963                {
964                    max_arity = max_arity.max(args);
965                }
966                stack.push(l);
967                stack.push(r);
968            }
969            TypeKind::Fun(a, b) => {
970                stack.push(a);
971                stack.push(b);
972            }
973            TypeKind::Tuple(ts) => {
974                for t in ts {
975                    stack.push(t);
976                }
977            }
978            TypeKind::Record(fields) => {
979                for (_, t) in fields {
980                    stack.push(t);
981                }
982            }
983        }
984    }
985    max_arity
986}
987
988impl<'g> Unifier<'g> {
989    fn new(max_infer_depth: Option<usize>) -> Self {
990        Self {
991            subs: Vec::new(),
992            gas: None,
993            max_infer_depth,
994            infer_depth: 0,
995        }
996    }
997
998    fn with_gas(gas: &'g mut GasMeter, max_infer_depth: Option<usize>) -> Self {
999        Self {
1000            subs: Vec::new(),
1001            gas: Some(gas),
1002            max_infer_depth,
1003            infer_depth: 0,
1004        }
1005    }
1006
1007    fn with_infer_depth<T>(
1008        &mut self,
1009        span: Span,
1010        f: impl FnOnce(&mut Self) -> Result<T, TypeError>,
1011    ) -> Result<T, TypeError> {
1012        if let Some(max) = self.max_infer_depth
1013            && self.infer_depth >= max
1014        {
1015            return Err(TypeError::Spanned {
1016                span,
1017                error: Box::new(TypeError::Internal(format!(
1018                    "maximum inference depth exceeded (max {max})"
1019                ))),
1020            });
1021        }
1022        self.infer_depth += 1;
1023        let res = f(self);
1024        self.infer_depth = self.infer_depth.saturating_sub(1);
1025        res
1026    }
1027
1028    fn charge_infer_node(&mut self) -> Result<(), TypeError> {
1029        let Some(gas) = self.gas.as_mut() else {
1030            return Ok(());
1031        };
1032        let cost = gas.costs.infer_node;
1033        gas.charge(cost)?;
1034        Ok(())
1035    }
1036
1037    fn charge_unify_step(&mut self) -> Result<(), TypeError> {
1038        let Some(gas) = self.gas.as_mut() else {
1039            return Ok(());
1040        };
1041        let cost = gas.costs.unify_step;
1042        gas.charge(cost)?;
1043        Ok(())
1044    }
1045
1046    fn bind_var(&mut self, id: TypeVarId, ty: Type) {
1047        if id >= self.subs.len() {
1048            self.subs.resize(id + 1, None);
1049        }
1050        self.subs[id] = Some(ty);
1051    }
1052
1053    fn prune(&mut self, ty: &Type) -> Type {
1054        match ty.as_ref() {
1055            TypeKind::Var(tv) => {
1056                let bound = self.subs.get(tv.id).and_then(|t| t.clone());
1057                match bound {
1058                    Some(bound) => {
1059                        let pruned = self.prune(&bound);
1060                        self.bind_var(tv.id, pruned.clone());
1061                        pruned
1062                    }
1063                    None => ty.clone(),
1064                }
1065            }
1066            TypeKind::Con(_) => ty.clone(),
1067            TypeKind::App(l, r) => {
1068                let l = self.prune(l);
1069                let r = self.prune(r);
1070                Type::app(l, r)
1071            }
1072            TypeKind::Fun(a, b) => {
1073                let a = self.prune(a);
1074                let b = self.prune(b);
1075                Type::fun(a, b)
1076            }
1077            TypeKind::Tuple(ts) => {
1078                Type::new(TypeKind::Tuple(ts.iter().map(|t| self.prune(t)).collect()))
1079            }
1080            TypeKind::Record(fields) => Type::new(TypeKind::Record(
1081                fields
1082                    .iter()
1083                    .map(|(name, ty)| (name.clone(), self.prune(ty)))
1084                    .collect(),
1085            )),
1086        }
1087    }
1088
1089    fn apply_type(&mut self, ty: &Type) -> Type {
1090        self.prune(ty)
1091    }
1092
1093    fn occurs(&mut self, id: TypeVarId, ty: &Type) -> bool {
1094        match self.prune(ty).as_ref() {
1095            TypeKind::Var(tv) => tv.id == id,
1096            TypeKind::Con(_) => false,
1097            TypeKind::App(l, r) => self.occurs(id, l) || self.occurs(id, r),
1098            TypeKind::Fun(a, b) => self.occurs(id, a) || self.occurs(id, b),
1099            TypeKind::Tuple(ts) => ts.iter().any(|t| self.occurs(id, t)),
1100            TypeKind::Record(fields) => fields.iter().any(|(_, ty)| self.occurs(id, ty)),
1101        }
1102    }
1103
1104    fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
1105        self.charge_unify_step()?;
1106        let t1 = self.prune(t1);
1107        let t2 = self.prune(t2);
1108        match (t1.as_ref(), t2.as_ref()) {
1109            (TypeKind::Var(a), TypeKind::Var(b)) if a.id == b.id => Ok(()),
1110            (TypeKind::Var(tv), other) | (other, TypeKind::Var(tv)) => {
1111                if self.occurs(tv.id, &Type::new(other.clone())) {
1112                    Err(TypeError::Occurs(
1113                        tv.id,
1114                        Type::new(other.clone()).to_string(),
1115                    ))
1116                } else {
1117                    self.bind_var(tv.id, Type::new(other.clone()));
1118                    Ok(())
1119                }
1120            }
1121            (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(()),
1122            (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1123                self.unify(l1, l2)?;
1124                self.unify(r1, r2)
1125            }
1126            (TypeKind::Fun(a1, b1), TypeKind::Fun(a2, b2)) => {
1127                self.unify(a1, a2)?;
1128                self.unify(b1, b2)
1129            }
1130            (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1131                if ts1.len() != ts2.len() {
1132                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1133                }
1134                for (a, b) in ts1.iter().zip(ts2.iter()) {
1135                    self.unify(a, b)?;
1136                }
1137                Ok(())
1138            }
1139            (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1140                if f1.len() != f2.len() {
1141                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1142                }
1143                for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1144                    if n1 != n2 {
1145                        return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1146                    }
1147                    self.unify(t1, t2)?;
1148                }
1149                Ok(())
1150            }
1151            (TypeKind::Record(fields), TypeKind::App(head, arg))
1152            | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1153                TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1154                    let elem_ty = record_elem_type_unifier(fields, self)?;
1155                    self.unify(arg, &elem_ty)
1156                }
1157                TypeKind::Var(tv) => {
1158                    self.unify(
1159                        &Type::new(TypeKind::Var(tv.clone())),
1160                        &Type::builtin(BuiltinTypeId::Dict),
1161                    )?;
1162                    let elem_ty = record_elem_type_unifier(fields, self)?;
1163                    self.unify(arg, &elem_ty)
1164                }
1165                _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1166            },
1167            _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1168        }
1169    }
1170
1171    fn into_subst(mut self) -> Subst {
1172        let mut out = Subst::new_sync();
1173        for id in 0..self.subs.len() {
1174            if let Some(ty) = self.subs[id].clone() {
1175                let pruned = self.prune(&ty);
1176                out = out.insert(id, pruned);
1177            }
1178        }
1179        out
1180    }
1181}
1182
1183fn record_elem_type_unifier(
1184    fields: &[(Symbol, Type)],
1185    unifier: &mut Unifier<'_>,
1186) -> Result<Type, TypeError> {
1187    let mut iter = fields.iter();
1188    let first = match iter.next() {
1189        Some((_, ty)) => ty.clone(),
1190        None => return Err(TypeError::UnsupportedExpr("empty record")),
1191    };
1192    for (_, ty) in iter {
1193        unifier.unify(&first, ty)?;
1194    }
1195    Ok(unifier.apply_type(&first))
1196}
1197
1198fn bind(tv: &TypeVar, t: &Type) -> Result<Subst, TypeError> {
1199    if let TypeKind::Var(var) = t.as_ref()
1200        && var.id == tv.id
1201    {
1202        return Ok(Subst::new_sync());
1203    }
1204    if t.ftv().contains(&tv.id) {
1205        Err(TypeError::Occurs(tv.id, t.to_string()))
1206    } else {
1207        Ok(Subst::new_sync().insert(tv.id, t.clone()))
1208    }
1209}
1210
1211fn record_elem_type(fields: &[(Symbol, Type)]) -> Result<(Subst, Type), TypeError> {
1212    let mut iter = fields.iter();
1213    let first = match iter.next() {
1214        Some((_, ty)) => ty.clone(),
1215        None => return Err(TypeError::UnsupportedExpr("empty record")),
1216    };
1217    let mut subst = Subst::new_sync();
1218    let mut current = first;
1219    for (_, ty) in iter {
1220        let s_next = unify(&current.apply(&subst), &ty.apply(&subst))?;
1221        subst = compose_subst(s_next, subst);
1222        current = current.apply(&subst);
1223    }
1224    Ok((subst.clone(), current.apply(&subst)))
1225}
1226
1227/// Compute a most-general unifier for two types.
1228///
1229/// This is the “pure” unifier: it returns an explicit substitution map and is
1230/// easy to read/compose in isolation. The type inference engine uses `Unifier`
1231/// directly to avoid allocating and composing persistent maps at every
1232/// unification step.
1233pub fn unify(t1: &Type, t2: &Type) -> Result<Subst, TypeError> {
1234    match (t1.as_ref(), t2.as_ref()) {
1235        (TypeKind::Fun(l1, r1), TypeKind::Fun(l2, r2)) => {
1236            let s1 = unify(l1, l2)?;
1237            let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1238            Ok(compose_subst(s2, s1))
1239        }
1240        (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1241            if f1.len() != f2.len() {
1242                return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1243            }
1244            let mut subst = Subst::new_sync();
1245            for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1246                if n1 != n2 {
1247                    return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1248                }
1249                let s_next = unify(&t1.apply(&subst), &t2.apply(&subst))?;
1250                subst = compose_subst(s_next, subst);
1251            }
1252            Ok(subst)
1253        }
1254        (TypeKind::Record(fields), TypeKind::App(head, arg))
1255        | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1256            TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1257                let (s_fields, elem_ty) = record_elem_type(fields)?;
1258                let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1259                Ok(compose_subst(s_arg, s_fields))
1260            }
1261            TypeKind::Var(tv) => {
1262                let s_head = bind(tv, &Type::builtin(BuiltinTypeId::Dict))?;
1263                let arg = arg.apply(&s_head);
1264                let (s_fields, elem_ty) = record_elem_type(fields)?;
1265                let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1266                Ok(compose_subst(s_arg, compose_subst(s_fields, s_head)))
1267            }
1268            _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1269        },
1270        (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1271            let s1 = unify(l1, l2)?;
1272            let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1273            Ok(compose_subst(s2, s1))
1274        }
1275        (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1276            if ts1.len() != ts2.len() {
1277                return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1278            }
1279            let mut s = Subst::new_sync();
1280            for (a, b) in ts1.iter().zip(ts2.iter()) {
1281                let s_next = unify(&a.apply(&s), &b.apply(&s))?;
1282                s = compose_subst(s_next, s);
1283            }
1284            Ok(s)
1285        }
1286        (TypeKind::Var(tv), t) | (t, TypeKind::Var(tv)) => bind(tv, &Type::new(t.clone())),
1287        (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(Subst::new_sync()),
1288        _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1289    }
1290}
1291
1292#[derive(Default, Debug, Clone)]
1293pub struct TypeEnv {
1294    pub values: HashTrieMapSync<Symbol, Vec<Scheme>>,
1295}
1296
1297impl TypeEnv {
1298    pub fn new() -> Self {
1299        Self {
1300            values: HashTrieMapSync::new_sync(),
1301        }
1302    }
1303
1304    pub fn extend(&mut self, name: Symbol, scheme: Scheme) {
1305        self.values = self.values.insert(name, vec![scheme]);
1306    }
1307
1308    pub fn extend_overload(&mut self, name: Symbol, scheme: Scheme) {
1309        let mut schemes = self.values.get(&name).cloned().unwrap_or_default();
1310        schemes.push(scheme);
1311        self.values = self.values.insert(name, schemes);
1312    }
1313
1314    pub fn remove(&mut self, name: &Symbol) {
1315        self.values = self.values.remove(name);
1316    }
1317
1318    pub fn lookup(&self, name: &Symbol) -> Option<&[Scheme]> {
1319        self.values.get(name).map(|schemes| schemes.as_slice())
1320    }
1321}
1322
1323impl Types for TypeEnv {
1324    fn apply(&self, s: &Subst) -> Self {
1325        let mut values = HashTrieMapSync::new_sync();
1326        for (k, v) in self.values.iter() {
1327            let updated = v
1328                .iter()
1329                .map(|scheme| {
1330                    // Most schemes in environments are monomorphic. Don't walk
1331                    // and rebuild trees unless we actually have work to do.
1332                    if scheme.vars.is_empty() && !subst_is_empty(s) {
1333                        scheme.apply(s)
1334                    } else {
1335                        scheme.clone()
1336                    }
1337                })
1338                .collect();
1339            values = values.insert(k.clone(), updated);
1340        }
1341        TypeEnv { values }
1342    }
1343
1344    fn ftv(&self) -> HashSet<TypeVarId> {
1345        self.values
1346            .iter()
1347            .flat_map(|(_, schemes)| schemes.iter().flat_map(Types::ftv))
1348            .collect()
1349    }
1350}
1351
1352#[derive(Default, Debug, Clone)]
1353pub struct TypeVarSupply {
1354    counter: TypeVarId,
1355}
1356
1357impl TypeVarSupply {
1358    pub fn new() -> Self {
1359        Self { counter: 0 }
1360    }
1361
1362    pub fn fresh(&mut self, name_hint: impl Into<Option<Symbol>>) -> TypeVar {
1363        let tv = TypeVar::new(self.counter, name_hint.into());
1364        self.counter += 1;
1365        tv
1366    }
1367}
1368
1369fn is_integral_literal_expr(expr: &Expr) -> bool {
1370    matches!(expr, Expr::Int(..) | Expr::Uint(..))
1371}
1372
1373/// Turn a monotype `typ` (plus constraints `preds`) into a polymorphic `Scheme`
1374/// by quantifying over the type variables not free in `env`.
1375pub fn generalize(env: &TypeEnv, preds: Vec<Predicate>, typ: Type) -> Scheme {
1376    let mut vars: Vec<TypeVar> = typ
1377        .ftv()
1378        .union(&preds.ftv())
1379        .copied()
1380        .collect::<HashSet<_>>()
1381        .difference(&env.ftv())
1382        .cloned()
1383        .map(|id| TypeVar::new(id, None))
1384        .collect();
1385    vars.sort_by_key(|v| v.id);
1386    Scheme::new(vars, preds, typ)
1387}
1388
1389pub fn instantiate(scheme: &Scheme, supply: &mut TypeVarSupply) -> (Vec<Predicate>, Type) {
1390    // Instantiate replaces all quantified variables with fresh unification
1391    // variables, preserving the original name as a debugging hint.
1392    let mut subst = Subst::new_sync();
1393    for v in &scheme.vars {
1394        subst = subst.insert(v.id, Type::var(supply.fresh(v.name.clone())));
1395    }
1396    (scheme.preds.apply(&subst), scheme.typ.apply(&subst))
1397}
1398
1399/// A named type parameter for an ADT (e.g. `a` in `List a`).
1400#[derive(Clone, Debug)]
1401pub struct AdtParam {
1402    pub name: Symbol,
1403    pub var: TypeVar,
1404}
1405
1406/// A single ADT variant with zero or more constructor arguments.
1407#[derive(Clone, Debug)]
1408pub struct AdtVariant {
1409    pub name: Symbol,
1410    pub args: Vec<Type>,
1411}
1412
1413/// A type declaration for an algebraic data type.
1414///
1415/// This only describes the *type* surface (params + variants). It does not
1416/// introduce any runtime values by itself. Runtime values are created by
1417/// injecting constructor schemes into the environment (see `inject_adt`).
1418#[derive(Clone, Debug)]
1419pub struct AdtDecl {
1420    pub name: Symbol,
1421    pub params: Vec<AdtParam>,
1422    pub variants: Vec<AdtVariant>,
1423}
1424
1425impl AdtDecl {
1426    pub fn new(name: &Symbol, param_names: &[Symbol], supply: &mut TypeVarSupply) -> Self {
1427        let params = param_names
1428            .iter()
1429            .map(|p| AdtParam {
1430                name: p.clone(),
1431                var: supply.fresh(Some(p.clone())),
1432            })
1433            .collect();
1434        Self {
1435            name: name.clone(),
1436            params,
1437            variants: Vec::new(),
1438        }
1439    }
1440
1441    pub fn param_type(&self, name: &Symbol) -> Option<Type> {
1442        self.params
1443            .iter()
1444            .find(|p| &p.name == name)
1445            .map(|p| Type::var(p.var.clone()))
1446    }
1447
1448    pub fn add_variant(&mut self, name: Symbol, args: Vec<Type>) {
1449        self.variants.push(AdtVariant { name, args });
1450    }
1451
1452    pub fn result_type(&self) -> Type {
1453        let mut ty = Type::con(&self.name, self.params.len());
1454        for param in &self.params {
1455            ty = Type::app(ty, Type::var(param.var.clone()));
1456        }
1457        ty
1458    }
1459
1460    /// Build constructor schemes of the form:
1461    /// `C :: a1 -> a2 -> ... -> T params`.
1462    pub fn constructor_schemes(&self) -> Vec<(Symbol, Scheme)> {
1463        let result_ty = self.result_type();
1464        let vars: Vec<TypeVar> = self.params.iter().map(|p| p.var.clone()).collect();
1465        let mut out = Vec::new();
1466        for variant in &self.variants {
1467            let mut typ = result_ty.clone();
1468            for arg in variant.args.iter().rev() {
1469                typ = Type::fun(arg.clone(), typ);
1470            }
1471            out.push((variant.name.clone(), Scheme::new(vars.clone(), vec![], typ)));
1472        }
1473        out
1474    }
1475}
1476
1477#[derive(Clone, Debug)]
1478pub struct Class {
1479    pub supers: Vec<Symbol>,
1480}
1481
1482impl Class {
1483    pub fn new(supers: Vec<Symbol>) -> Self {
1484        Self { supers }
1485    }
1486}
1487
1488#[derive(Clone, Debug)]
1489pub struct Instance {
1490    pub context: Vec<Predicate>,
1491    pub head: Predicate,
1492}
1493
1494impl Instance {
1495    pub fn new(context: Vec<Predicate>, head: Predicate) -> Self {
1496        Self { context, head }
1497    }
1498}
1499
1500#[derive(Default, Debug, Clone)]
1501pub struct ClassEnv {
1502    pub classes: HashMap<Symbol, Class>,
1503    pub instances: HashMap<Symbol, Vec<Instance>>,
1504}
1505
1506impl ClassEnv {
1507    pub fn new() -> Self {
1508        Self {
1509            classes: HashMap::new(),
1510            instances: HashMap::new(),
1511        }
1512    }
1513
1514    pub fn add_class(&mut self, name: Symbol, supers: Vec<Symbol>) {
1515        self.classes.insert(name, Class::new(supers));
1516    }
1517
1518    pub fn add_instance(&mut self, class: Symbol, inst: Instance) {
1519        self.instances.entry(class).or_default().push(inst);
1520    }
1521
1522    pub fn supers_of(&self, class: &Symbol) -> Vec<Symbol> {
1523        self.classes
1524            .get(class)
1525            .map(|c| c.supers.clone())
1526            .unwrap_or_default()
1527    }
1528}
1529
1530pub fn entails(
1531    class_env: &ClassEnv,
1532    given: &[Predicate],
1533    pred: &Predicate,
1534) -> Result<bool, TypeError> {
1535    // Expand given with superclasses.
1536    let mut closure: Vec<Predicate> = given.to_vec();
1537    let mut i = 0;
1538    while i < closure.len() {
1539        let p = closure[i].clone();
1540        for sup in class_env.supers_of(&p.class) {
1541            closure.push(Predicate::new(sup, p.typ.clone()));
1542        }
1543        i += 1;
1544    }
1545
1546    if closure
1547        .iter()
1548        .any(|p| p.class == pred.class && p.typ == pred.typ)
1549    {
1550        return Ok(true);
1551    }
1552
1553    if !class_env.classes.contains_key(&pred.class) {
1554        return Err(TypeError::UnknownClass(pred.class.clone()));
1555    }
1556
1557    if let Some(instances) = class_env.instances.get(&pred.class) {
1558        for inst in instances {
1559            if let Ok(s) = unify(&inst.head.typ, &pred.typ) {
1560                let ctx = inst.context.apply(&s);
1561                if ctx
1562                    .iter()
1563                    .all(|c| entails(class_env, &closure, c).unwrap_or(false))
1564                {
1565                    return Ok(true);
1566                }
1567            }
1568        }
1569    }
1570    Ok(false)
1571}
1572
1573#[derive(Default, Debug, Clone)]
1574pub struct TypeSystem {
1575    pub env: TypeEnv,
1576    pub classes: ClassEnv,
1577    pub adts: HashMap<Symbol, AdtDecl>,
1578    pub class_info: HashMap<Symbol, ClassInfo>,
1579    pub class_methods: HashMap<Symbol, ClassMethodInfo>,
1580    /// Names introduced by `declare fn` (forward declarations).
1581    ///
1582    /// These are placeholders in the type environment and must not block a later
1583    /// real definition (e.g. `fn foo = ...` or host/CLI injection).
1584    pub declared_values: HashSet<Symbol>,
1585    pub supply: TypeVarSupply,
1586    limits: TypeSystemLimits,
1587}
1588
1589/// Semantic information about a type class declaration, derived from Rex source.
1590///
1591/// Design notes (WARM):
1592/// - We keep this explicit and data-oriented: it makes review easy and keeps costs visible.
1593/// - Rex represents multi-parameter classes by encoding the parameters as a tuple in the
1594///   single `Predicate.typ` slot. For a unary class `C a` the predicate is `C a`. For a
1595///   binary class `C t a` the predicate is `C (t, a)`, etc.
1596/// - This keeps the runtime/type-inference machinery simple: instance matching is still
1597///   “unify the predicate types”, and no separate arity tracking is needed.
1598#[derive(Clone, Debug)]
1599pub struct ClassInfo {
1600    pub name: Symbol,
1601    pub params: Vec<Symbol>,
1602    pub supers: Vec<Symbol>,
1603    pub methods: BTreeMap<Symbol, Scheme>,
1604}
1605
1606#[derive(Clone, Debug)]
1607pub struct ClassMethodInfo {
1608    pub class: Symbol,
1609    pub scheme: Scheme,
1610}
1611
1612#[derive(Clone, Debug)]
1613pub struct PreparedInstanceDecl {
1614    pub span: Span,
1615    pub class: Symbol,
1616    pub head: Type,
1617    pub context: Vec<Predicate>,
1618}
1619
1620impl TypeSystem {
1621    pub fn new() -> Self {
1622        Self {
1623            env: TypeEnv::new(),
1624            classes: ClassEnv::new(),
1625            adts: HashMap::new(),
1626            class_info: HashMap::new(),
1627            class_methods: HashMap::new(),
1628            declared_values: HashSet::new(),
1629            supply: TypeVarSupply::new(),
1630            limits: TypeSystemLimits::default(),
1631        }
1632    }
1633
1634    pub fn fresh_type_var(&mut self, name: Option<Symbol>) -> TypeVar {
1635        self.supply.fresh(name)
1636    }
1637
1638    pub fn set_limits(&mut self, limits: TypeSystemLimits) {
1639        self.limits = limits;
1640    }
1641
1642    pub fn new_with_prelude() -> Result<Self, TypeError> {
1643        let mut ts = TypeSystem::new();
1644        prelude::build_prelude(&mut ts)?;
1645        Ok(ts)
1646    }
1647
1648    fn register_decl(&mut self, decl: &Decl) -> Result<(), TypeError> {
1649        match decl {
1650            Decl::Type(ty) => self.register_type_decl(ty),
1651            Decl::Class(class_decl) => self.register_class_decl(class_decl),
1652            Decl::Instance(inst_decl) => {
1653                let _ = self.register_instance_decl(inst_decl)?;
1654                Ok(())
1655            }
1656            Decl::Fn(fd) => self.register_fn_decls(std::slice::from_ref(fd)),
1657            Decl::DeclareFn(fd) => self.inject_declare_fn_decl(fd),
1658            Decl::Import(..) => Ok(()),
1659        }
1660    }
1661
1662    pub fn register_decls(&mut self, decls: &[Decl]) -> Result<(), TypeError> {
1663        let mut pending_fns: Vec<FnDecl> = Vec::new();
1664        for decl in decls {
1665            if let Decl::Fn(fd) = decl {
1666                pending_fns.push(fd.clone());
1667                continue;
1668            }
1669
1670            if !pending_fns.is_empty() {
1671                self.register_fn_decls(&pending_fns)?;
1672                pending_fns.clear();
1673            }
1674
1675            self.register_decl(decl)?;
1676        }
1677        if !pending_fns.is_empty() {
1678            self.register_fn_decls(&pending_fns)?;
1679        }
1680        Ok(())
1681    }
1682
1683    pub fn add_value(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1684        let name = sym(name.as_ref());
1685        self.declared_values.remove(&name);
1686        self.env.extend(name, scheme);
1687    }
1688
1689    pub fn add_overload(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1690        let name = sym(name.as_ref());
1691        self.declared_values.remove(&name);
1692        self.env.extend_overload(name, scheme);
1693    }
1694
1695    pub fn register_instance(&mut self, class: impl AsRef<str>, inst: Instance) {
1696        self.classes.add_instance(sym(class.as_ref()), inst);
1697    }
1698
1699    pub fn register_class_decl(&mut self, decl: &ClassDecl) -> Result<(), TypeError> {
1700        let span = decl.span;
1701        (|| {
1702            // Classes are global, and Rex does not support reopening/merging them.
1703            // Allowing that would be a long-term maintenance hazard: it creates
1704            // spooky-action-at-a-distance across modules and makes reviews harder.
1705            if self.class_info.contains_key(&decl.name)
1706                || self.classes.classes.contains_key(&decl.name)
1707            {
1708                return Err(TypeError::DuplicateClass(decl.name.clone()));
1709            }
1710            if decl.params.is_empty() {
1711                return Err(TypeError::InvalidClassArity {
1712                    class: decl.name.clone(),
1713                    got: decl.params.len(),
1714                });
1715            }
1716            let params = decl.params.clone();
1717
1718            // Register the superclass relationships in the class environment.
1719            //
1720            // We only accept `<= C param` style superclasses for now. Anything
1721            // fancier would require storing type-level relationships in `ClassEnv`,
1722            // which Rex does not currently model.
1723            let mut supers = Vec::with_capacity(decl.supers.len());
1724            if !decl.supers.is_empty() && params.len() != 1 {
1725                return Err(TypeError::UnsupportedExpr(
1726                    "multi-parameter classes cannot declare superclasses yet",
1727                ));
1728            }
1729            for sup in &decl.supers {
1730                let mut vars = HashMap::new();
1731                let param = params[0].clone();
1732                let param_tv = self.supply.fresh(Some(param.clone()));
1733                vars.insert(param, param_tv.clone());
1734                let sup_ty = type_from_annotation_expr_vars(
1735                    &self.adts,
1736                    &sup.typ,
1737                    &mut vars,
1738                    &mut self.supply,
1739                )?;
1740                if sup_ty != Type::var(param_tv) {
1741                    return Err(TypeError::UnsupportedExpr(
1742                        "superclass constraints must be of the form `<= C a`",
1743                    ));
1744                }
1745                supers.push(sup.class.to_dotted_symbol());
1746            }
1747
1748            self.classes.add_class(decl.name.clone(), supers.clone());
1749
1750            let mut methods = BTreeMap::new();
1751            for ClassMethodSig { name, typ } in &decl.methods {
1752                if self.env.lookup(name).is_some() || self.class_methods.contains_key(name) {
1753                    return Err(TypeError::DuplicateClassMethod(name.clone()));
1754                }
1755
1756                let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1757                let mut param_tvs: Vec<TypeVar> = Vec::with_capacity(params.len());
1758                for param in &params {
1759                    let tv = self.supply.fresh(Some(param.clone()));
1760                    vars.insert(param.clone(), tv.clone());
1761                    param_tvs.push(tv);
1762                }
1763
1764                let ty =
1765                    type_from_annotation_expr_vars(&self.adts, typ, &mut vars, &mut self.supply)?;
1766
1767                let mut scheme_vars: Vec<TypeVar> = vars.values().cloned().collect();
1768                scheme_vars.sort_by_key(|tv| tv.id);
1769                scheme_vars.dedup_by_key(|tv| tv.id);
1770
1771                let class_pred = Predicate {
1772                    class: decl.name.clone(),
1773                    typ: if param_tvs.len() == 1 {
1774                        Type::var(param_tvs[0].clone())
1775                    } else {
1776                        Type::tuple(param_tvs.into_iter().map(Type::var).collect())
1777                    },
1778                };
1779                let scheme = Scheme::new(scheme_vars, vec![class_pred], ty);
1780
1781                self.env.extend(name.clone(), scheme.clone());
1782                self.class_methods.insert(
1783                    name.clone(),
1784                    ClassMethodInfo {
1785                        class: decl.name.clone(),
1786                        scheme: scheme.clone(),
1787                    },
1788                );
1789                methods.insert(name.clone(), scheme);
1790            }
1791
1792            self.class_info.insert(
1793                decl.name.clone(),
1794                ClassInfo {
1795                    name: decl.name.clone(),
1796                    params,
1797                    supers,
1798                    methods,
1799                },
1800            );
1801            Ok(())
1802        })()
1803        .map_err(|err| with_span(&span, err))
1804    }
1805
1806    pub fn register_instance_decl(
1807        &mut self,
1808        decl: &InstanceDecl,
1809    ) -> Result<PreparedInstanceDecl, TypeError> {
1810        let span = decl.span;
1811        (|| {
1812            let class = decl.class.clone();
1813            if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
1814                return Err(TypeError::UnknownClass(class));
1815            }
1816
1817            let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1818            let head = type_from_annotation_expr_vars(
1819                &self.adts,
1820                &decl.head,
1821                &mut vars,
1822                &mut self.supply,
1823            )?;
1824            let context = predicates_from_constraints(
1825                &self.adts,
1826                &decl.context,
1827                &mut vars,
1828                &mut self.supply,
1829            )?;
1830
1831            let inst = Instance::new(
1832                context.clone(),
1833                Predicate {
1834                    class: decl.class.clone(),
1835                    typ: head.clone(),
1836                },
1837            );
1838
1839            // Validate method list against the class declaration if present.
1840            if let Some(info) = self.class_info.get(&decl.class) {
1841                for method in &decl.methods {
1842                    if !info.methods.contains_key(&method.name) {
1843                        return Err(TypeError::UnknownInstanceMethod {
1844                            class: decl.class.clone(),
1845                            method: method.name.clone(),
1846                        });
1847                    }
1848                }
1849                for method_name in info.methods.keys() {
1850                    if !decl.methods.iter().any(|m| &m.name == method_name) {
1851                        return Err(TypeError::MissingInstanceMethod {
1852                            class: decl.class.clone(),
1853                            method: method_name.clone(),
1854                        });
1855                    }
1856                }
1857            }
1858
1859            self.classes.add_instance(decl.class.clone(), inst);
1860            Ok(PreparedInstanceDecl {
1861                span,
1862                class: decl.class.clone(),
1863                head,
1864                context,
1865            })
1866        })()
1867        .map_err(|err| with_span(&span, err))
1868    }
1869
1870    pub fn prepare_instance_decl(
1871        &mut self,
1872        decl: &InstanceDecl,
1873    ) -> Result<PreparedInstanceDecl, TypeError> {
1874        let span = decl.span;
1875        (|| {
1876            let class = decl.class.clone();
1877            if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
1878                return Err(TypeError::UnknownClass(class));
1879            }
1880
1881            let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1882            let head = type_from_annotation_expr_vars(
1883                &self.adts,
1884                &decl.head,
1885                &mut vars,
1886                &mut self.supply,
1887            )?;
1888            let context = predicates_from_constraints(
1889                &self.adts,
1890                &decl.context,
1891                &mut vars,
1892                &mut self.supply,
1893            )?;
1894
1895            // Validate method list against the class declaration if present.
1896            if let Some(info) = self.class_info.get(&decl.class) {
1897                for method in &decl.methods {
1898                    if !info.methods.contains_key(&method.name) {
1899                        return Err(TypeError::UnknownInstanceMethod {
1900                            class: decl.class.clone(),
1901                            method: method.name.clone(),
1902                        });
1903                    }
1904                }
1905                for method_name in info.methods.keys() {
1906                    if !decl.methods.iter().any(|m| &m.name == method_name) {
1907                        return Err(TypeError::MissingInstanceMethod {
1908                            class: decl.class.clone(),
1909                            method: method_name.clone(),
1910                        });
1911                    }
1912                }
1913            }
1914
1915            Ok(PreparedInstanceDecl {
1916                span,
1917                class: decl.class.clone(),
1918                head,
1919                context,
1920            })
1921        })()
1922        .map_err(|err| with_span(&span, err))
1923    }
1924
1925    pub fn register_fn_decls(&mut self, decls: &[FnDecl]) -> Result<(), TypeError> {
1926        if decls.is_empty() {
1927            return Ok(());
1928        }
1929
1930        let saved_env = self.env.clone();
1931        let saved_declared = self.declared_values.clone();
1932
1933        let result: Result<(), TypeError> = (|| {
1934            #[derive(Clone)]
1935            struct FnInfo {
1936                decl: FnDecl,
1937                expected: Type,
1938                declared_preds: Vec<Predicate>,
1939                scheme: Scheme,
1940                ann_vars: HashMap<Symbol, TypeVar>,
1941            }
1942
1943            let mut infos: Vec<FnInfo> = Vec::with_capacity(decls.len());
1944            let mut seen_names = HashSet::new();
1945
1946            for decl in decls {
1947                let span = decl.span;
1948                let info = (|| {
1949                    let name = &decl.name.name;
1950                    if !seen_names.insert(name.clone()) {
1951                        return Err(TypeError::DuplicateValue(name.clone()));
1952                    }
1953
1954                    if self.env.lookup(name).is_some() {
1955                        if self.declared_values.remove(name) {
1956                            // A forward declaration should not block the real definition.
1957                            self.env.remove(name);
1958                        } else {
1959                            return Err(TypeError::DuplicateValue(name.clone()));
1960                        }
1961                    }
1962
1963                    let mut sig = decl.ret.clone();
1964                    for (_, ann) in decl.params.iter().rev() {
1965                        let span = Span::from_begin_end(ann.span().begin, sig.span().end);
1966                        sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
1967                    }
1968
1969                    let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
1970                    let expected = type_from_annotation_expr_vars(
1971                        &self.adts,
1972                        &sig,
1973                        &mut ann_vars,
1974                        &mut self.supply,
1975                    )?;
1976                    let declared_preds = predicates_from_constraints(
1977                        &self.adts,
1978                        &decl.constraints,
1979                        &mut ann_vars,
1980                        &mut self.supply,
1981                    )?;
1982
1983                    // Validate that declared constraints are well-formed.
1984                    let var_arities: HashMap<TypeVarId, usize> = ann_vars
1985                        .values()
1986                        .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
1987                        .collect();
1988                    for pred in &declared_preds {
1989                        let _ = entails(&self.classes, &[], pred)?;
1990                        let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
1991                        else {
1992                            continue;
1993                        };
1994                        let args: Vec<Type> = if expected_arities.len() == 1 {
1995                            vec![pred.typ.clone()]
1996                        } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
1997                            if parts.len() != expected_arities.len() {
1998                                continue;
1999                            }
2000                            parts.clone()
2001                        } else {
2002                            continue;
2003                        };
2004
2005                        for (arg, expected_arity) in
2006                            args.iter().zip(expected_arities.iter().copied())
2007                        {
2008                            let got =
2009                                type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
2010                                    TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
2011                                    _ => None,
2012                                });
2013                            let Some(got) = got else {
2014                                continue;
2015                            };
2016                            if got != expected_arity {
2017                                return Err(TypeError::KindMismatch {
2018                                    class: pred.class.clone(),
2019                                    expected: expected_arity,
2020                                    got,
2021                                    typ: arg.to_string(),
2022                                });
2023                            }
2024                        }
2025                    }
2026
2027                    let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
2028                    vars.sort_by_key(|v| v.id);
2029                    let scheme = Scheme::new(vars, declared_preds.clone(), expected.clone());
2030                    reject_ambiguous_scheme(&scheme)?;
2031
2032                    Ok(FnInfo {
2033                        decl: decl.clone(),
2034                        expected,
2035                        declared_preds,
2036                        scheme,
2037                        ann_vars,
2038                    })
2039                })();
2040
2041                infos.push(info.map_err(|err| with_span(&span, err))?);
2042            }
2043
2044            // Seed environment with all declared signatures first so fn bodies
2045            // can reference each other recursively (let-rec semantics).
2046            for info in &infos {
2047                self.env
2048                    .extend(info.decl.name.name.clone(), info.scheme.clone());
2049            }
2050
2051            for info in infos {
2052                let span = info.decl.span;
2053                let mut lam_body = info.decl.body.clone();
2054                let mut lam_end = lam_body.span().end;
2055                for (param, ann) in info.decl.params.iter().rev() {
2056                    let lam_constraints = Vec::new();
2057                    let span = Span::from_begin_end(param.span.begin, lam_end);
2058                    lam_body = Arc::new(Expr::Lam(
2059                        span,
2060                        Scope::new_sync(),
2061                        param.clone(),
2062                        Some(ann.clone()),
2063                        lam_constraints,
2064                        lam_body,
2065                    ));
2066                    lam_end = lam_body.span().end;
2067                }
2068
2069                let (typed, preds, inferred) = infer_typed(self, lam_body.as_ref())?;
2070                let s = unify(&inferred, &info.expected)?;
2071                let preds = preds.apply(&s);
2072                let inferred = inferred.apply(&s);
2073                let declared_preds = info.declared_preds.apply(&s);
2074                let expected = info.expected.apply(&s);
2075
2076                // Keep kind checks aligned with existing `inject_fn_decl` logic.
2077                let var_arities: HashMap<TypeVarId, usize> = info
2078                    .ann_vars
2079                    .values()
2080                    .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
2081                    .collect();
2082                for pred in &declared_preds {
2083                    let _ = entails(&self.classes, &[], pred)?;
2084                    let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
2085                    else {
2086                        continue;
2087                    };
2088                    let args: Vec<Type> = if expected_arities.len() == 1 {
2089                        vec![pred.typ.clone()]
2090                    } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2091                        if parts.len() != expected_arities.len() {
2092                            continue;
2093                        }
2094                        parts.clone()
2095                    } else {
2096                        continue;
2097                    };
2098
2099                    for (arg, expected_arity) in args.iter().zip(expected_arities.iter().copied()) {
2100                        let got = type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
2101                            TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
2102                            _ => None,
2103                        });
2104                        let Some(got) = got else {
2105                            continue;
2106                        };
2107                        if got != expected_arity {
2108                            return Err(with_span(
2109                                &span,
2110                                TypeError::KindMismatch {
2111                                    class: pred.class.clone(),
2112                                    expected: expected_arity,
2113                                    got,
2114                                    typ: arg.to_string(),
2115                                },
2116                            ));
2117                        }
2118                    }
2119                }
2120
2121                check_non_ground_predicates_declared(&self.classes, &declared_preds, &preds)
2122                    .map_err(|err| with_span(&span, err))?;
2123
2124                let _ = inferred;
2125                let _ = typed;
2126            }
2127
2128            Ok(())
2129        })();
2130
2131        if result.is_err() {
2132            self.env = saved_env;
2133            self.declared_values = saved_declared;
2134        }
2135        result
2136    }
2137
2138    pub fn inject_declare_fn_decl(&mut self, decl: &DeclareFnDecl) -> Result<(), TypeError> {
2139        let span = decl.span;
2140        (|| {
2141            // Build the declared signature type.
2142            let mut sig = decl.ret.clone();
2143            for (_, ann) in decl.params.iter().rev() {
2144                let span = Span::from_begin_end(ann.span().begin, sig.span().end);
2145                sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
2146            }
2147
2148            let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
2149            let expected =
2150                type_from_annotation_expr_vars(&self.adts, &sig, &mut ann_vars, &mut self.supply)?;
2151            let declared_preds = predicates_from_constraints(
2152                &self.adts,
2153                &decl.constraints,
2154                &mut ann_vars,
2155                &mut self.supply,
2156            )?;
2157
2158            let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
2159            vars.sort_by_key(|v| v.id);
2160            let scheme = Scheme::new(vars, declared_preds, expected);
2161            reject_ambiguous_scheme(&scheme)?;
2162
2163            // Validate referenced classes exist (and are spelled correctly).
2164            for pred in &scheme.preds {
2165                let _ = entails(&self.classes, &[], pred)?;
2166            }
2167
2168            let name = &decl.name.name;
2169
2170            // If there is already a real definition (prelude/host/`fn`), treat
2171            // `declare fn` as documentation only and ignore it.
2172            if self.env.lookup(name).is_some() && !self.declared_values.contains(name) {
2173                return Ok(());
2174            }
2175
2176            if let Some(existing) = self.env.lookup(name) {
2177                if existing.iter().any(|s| scheme_compatible(s, &scheme)) {
2178                    return Ok(());
2179                }
2180                return Err(TypeError::DuplicateValue(decl.name.name.clone()));
2181            }
2182
2183            self.env.extend(decl.name.name.clone(), scheme);
2184            self.declared_values.insert(decl.name.name.clone());
2185            Ok(())
2186        })()
2187        .map_err(|err| with_span(&span, err))
2188    }
2189
2190    pub fn instantiate_class_method_for_head(
2191        &mut self,
2192        class: &Symbol,
2193        method: &Symbol,
2194        head: &Type,
2195    ) -> Result<Type, TypeError> {
2196        let info = self
2197            .class_info
2198            .get(class)
2199            .ok_or_else(|| TypeError::UnknownClass(class.clone()))?;
2200        let scheme = info
2201            .methods
2202            .get(method)
2203            .ok_or_else(|| TypeError::UnknownInstanceMethod {
2204                class: class.clone(),
2205                method: method.clone(),
2206            })?;
2207
2208        let (preds, typ) = instantiate(scheme, &mut self.supply);
2209        let class_pred =
2210            preds
2211                .iter()
2212                .find(|p| &p.class == class)
2213                .ok_or(TypeError::UnsupportedExpr(
2214                    "class method scheme missing class predicate",
2215                ))?;
2216        let s = unify(&class_pred.typ, head)?;
2217        Ok(typ.apply(&s))
2218    }
2219
2220    pub fn typecheck_instance_method(
2221        &mut self,
2222        prepared: &PreparedInstanceDecl,
2223        method: &InstanceMethodImpl,
2224    ) -> Result<TypedExpr, TypeError> {
2225        let expected =
2226            self.instantiate_class_method_for_head(&prepared.class, &method.name, &prepared.head)?;
2227        let (typed, preds, actual) = infer_typed(self, method.body.as_ref())?;
2228        let s = unify(&actual, &expected)?;
2229        let typed = typed.apply(&s);
2230        let preds = preds.apply(&s);
2231
2232        // The only legal “given” constraints inside an instance method are the
2233        // instance context (plus superclass closure, plus the instance head
2234        // itself). We do *not* allow instance
2235        // search for non-ground constraints here, because that would be unsound:
2236        // a type variable would unify with any concrete instance head.
2237        let mut given = prepared.context.clone();
2238
2239        // Allow recursive instance methods (e.g. `Eq (List a)` calling `(==)`
2240        // on the tail). This is dictionary recursion, not instance search.
2241        given.push(Predicate::new(
2242            prepared.class.clone(),
2243            prepared.head.clone(),
2244        ));
2245        let mut i = 0;
2246        while i < given.len() {
2247            let p = given[i].clone();
2248            for sup in self.classes.supers_of(&p.class) {
2249                given.push(Predicate::new(sup, p.typ.clone()));
2250            }
2251            i += 1;
2252        }
2253
2254        for pred in &preds {
2255            if pred.typ.ftv().is_empty() {
2256                if !entails(&self.classes, &given, pred)? {
2257                    return Err(TypeError::NoInstance(
2258                        pred.class.clone(),
2259                        pred.typ.to_string(),
2260                    ));
2261                }
2262            } else if !given
2263                .iter()
2264                .any(|p| p.class == pred.class && p.typ == pred.typ)
2265            {
2266                return Err(TypeError::MissingInstanceConstraint {
2267                    method: method.name.clone(),
2268                    class: pred.class.clone(),
2269                    typ: pred.typ.to_string(),
2270                });
2271            }
2272        }
2273
2274        Ok(typed)
2275    }
2276
2277    /// Register constructor schemes for an ADT in the type environment.
2278    /// This makes constructors (e.g. `Some`, `None`, `Ok`, `Err`) available
2279    /// to the type checker as normal values.
2280    pub fn register_adt(&mut self, adt: &AdtDecl) {
2281        self.adts.insert(adt.name.clone(), adt.clone());
2282        for (name, scheme) in adt.constructor_schemes() {
2283            self.register_value_scheme(&name, scheme);
2284        }
2285    }
2286
2287    pub fn adt_from_decl(&mut self, decl: &TypeDecl) -> Result<AdtDecl, TypeError> {
2288        let mut adt = AdtDecl::new(&decl.name, &decl.params, &mut self.supply);
2289        let mut param_map: HashMap<Symbol, TypeVar> = HashMap::new();
2290        for param in &adt.params {
2291            param_map.insert(param.name.clone(), param.var.clone());
2292        }
2293
2294        for variant in &decl.variants {
2295            let mut args = Vec::new();
2296            for arg in &variant.args {
2297                let ty = self.type_from_expr(decl, &param_map, arg)?;
2298                args.push(ty);
2299            }
2300            adt.add_variant(variant.name.clone(), args);
2301        }
2302        Ok(adt)
2303    }
2304
2305    pub fn register_type_decl(&mut self, decl: &TypeDecl) -> Result<(), TypeError> {
2306        if BuiltinTypeId::from_symbol(&decl.name).is_some() {
2307            return Err(TypeError::ReservedTypeName(decl.name.clone()));
2308        }
2309        let adt = self.adt_from_decl(decl)?;
2310        self.register_adt(&adt);
2311        Ok(())
2312    }
2313
2314    fn type_from_expr(
2315        &mut self,
2316        decl: &TypeDecl,
2317        params: &HashMap<Symbol, TypeVar>,
2318        expr: &TypeExpr,
2319    ) -> Result<Type, TypeError> {
2320        let span = *expr.span();
2321        let res = (|| match expr {
2322            TypeExpr::Name(_, name) => {
2323                let name_sym = name.to_dotted_symbol();
2324                if let Some(tv) = params.get(&name_sym) {
2325                    Ok(Type::var(tv.clone()))
2326                } else {
2327                    let name = normalize_type_name(&name_sym);
2328                    if let Some(arity) = self.type_arity(decl, &name) {
2329                        Ok(Type::con(name, arity))
2330                    } else {
2331                        Err(TypeError::UnknownTypeName(name))
2332                    }
2333                }
2334            }
2335            TypeExpr::App(_, fun, arg) => {
2336                let fty = self.type_from_expr(decl, params, fun)?;
2337                let aty = self.type_from_expr(decl, params, arg)?;
2338                Ok(type_app_with_result_syntax(fty, aty))
2339            }
2340            TypeExpr::Fun(_, arg, ret) => {
2341                let arg_ty = self.type_from_expr(decl, params, arg)?;
2342                let ret_ty = self.type_from_expr(decl, params, ret)?;
2343                Ok(Type::fun(arg_ty, ret_ty))
2344            }
2345            TypeExpr::Tuple(_, elems) => {
2346                let mut out = Vec::new();
2347                for elem in elems {
2348                    out.push(self.type_from_expr(decl, params, elem)?);
2349                }
2350                Ok(Type::tuple(out))
2351            }
2352            TypeExpr::Record(_, fields) => {
2353                let mut out = Vec::new();
2354                for (name, ty) in fields {
2355                    out.push((name.clone(), self.type_from_expr(decl, params, ty)?));
2356                }
2357                Ok(Type::record(out))
2358            }
2359        })();
2360        res.map_err(|err| with_span(&span, err))
2361    }
2362
2363    fn type_arity(&self, decl: &TypeDecl, name: &Symbol) -> Option<usize> {
2364        if &decl.name == name {
2365            return Some(decl.params.len());
2366        }
2367        if let Some(adt) = self.adts.get(name) {
2368            return Some(adt.params.len());
2369        }
2370        BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2371    }
2372
2373    fn register_value_scheme(&mut self, name: &Symbol, scheme: Scheme) {
2374        match self.env.lookup(name) {
2375            None => self.env.extend(name.clone(), scheme),
2376            Some(existing) => {
2377                if existing.iter().any(|s| unify(&s.typ, &scheme.typ).is_ok()) {
2378                    return;
2379                }
2380                self.env.extend_overload(name.clone(), scheme);
2381            }
2382        }
2383    }
2384
2385    fn expected_class_param_arities(&self, class: &Symbol) -> Option<Vec<usize>> {
2386        let info = self.class_info.get(class)?;
2387        let mut out = vec![0usize; info.params.len()];
2388        for scheme in info.methods.values() {
2389            for (idx, param) in info.params.iter().enumerate() {
2390                let Some(tv) = scheme.vars.iter().find(|v| v.name.as_ref() == Some(param)) else {
2391                    continue;
2392                };
2393                out[idx] = out[idx].max(max_head_app_arity_for_var(&scheme.typ, tv.id));
2394            }
2395        }
2396        Some(out)
2397    }
2398
2399    fn check_predicate_kind(&self, pred: &Predicate) -> Result<(), TypeError> {
2400        let Some(expected) = self.expected_class_param_arities(&pred.class) else {
2401            // Host-injected classes (via Rust API) won't have `class_info`.
2402            return Ok(());
2403        };
2404
2405        let args: Vec<Type> = if expected.len() == 1 {
2406            vec![pred.typ.clone()]
2407        } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2408            if parts.len() != expected.len() {
2409                return Ok(());
2410            }
2411            parts.clone()
2412        } else {
2413            return Ok(());
2414        };
2415
2416        for (arg, expected_arity) in args.iter().zip(expected.iter().copied()) {
2417            let Some(got) = type_term_remaining_arity(arg) else {
2418                // If we can't determine the arity (e.g. a bare type var), skip:
2419                // call sites may fix it up, and Rex does not currently do full
2420                // kind inference.
2421                continue;
2422            };
2423            if got != expected_arity {
2424                return Err(TypeError::KindMismatch {
2425                    class: pred.class.clone(),
2426                    expected: expected_arity,
2427                    got,
2428                    typ: arg.to_string(),
2429                });
2430            }
2431        }
2432        Ok(())
2433    }
2434
2435    fn check_predicate_kinds(&self, preds: &[Predicate]) -> Result<(), TypeError> {
2436        for pred in preds {
2437            self.check_predicate_kind(pred)?;
2438        }
2439        Ok(())
2440    }
2441}
2442
2443fn type_from_annotation_expr(
2444    adts: &HashMap<Symbol, AdtDecl>,
2445    expr: &TypeExpr,
2446) -> Result<Type, TypeError> {
2447    let span = *expr.span();
2448    let res = (|| match expr {
2449        TypeExpr::Name(_, name) => {
2450            let name = normalize_type_name(&name.to_dotted_symbol());
2451            match annotation_type_arity(adts, &name) {
2452                Some(arity) => Ok(Type::con(name, arity)),
2453                None => Err(TypeError::UnknownTypeName(name)),
2454            }
2455        }
2456        TypeExpr::App(_, fun, arg) => {
2457            let fty = type_from_annotation_expr(adts, fun)?;
2458            let aty = type_from_annotation_expr(adts, arg)?;
2459            Ok(type_app_with_result_syntax(fty, aty))
2460        }
2461        TypeExpr::Fun(_, arg, ret) => {
2462            let arg_ty = type_from_annotation_expr(adts, arg)?;
2463            let ret_ty = type_from_annotation_expr(adts, ret)?;
2464            Ok(Type::fun(arg_ty, ret_ty))
2465        }
2466        TypeExpr::Tuple(_, elems) => {
2467            let mut out = Vec::new();
2468            for elem in elems {
2469                out.push(type_from_annotation_expr(adts, elem)?);
2470            }
2471            Ok(Type::tuple(out))
2472        }
2473        TypeExpr::Record(_, fields) => {
2474            let mut out = Vec::new();
2475            for (name, ty) in fields {
2476                out.push((name.clone(), type_from_annotation_expr(adts, ty)?));
2477            }
2478            Ok(Type::record(out))
2479        }
2480    })();
2481    res.map_err(|err| with_span(&span, err))
2482}
2483
2484fn type_from_annotation_expr_vars(
2485    adts: &HashMap<Symbol, AdtDecl>,
2486    expr: &TypeExpr,
2487    vars: &mut HashMap<Symbol, TypeVar>,
2488    supply: &mut TypeVarSupply,
2489) -> Result<Type, TypeError> {
2490    let span = *expr.span();
2491    let res = (|| match expr {
2492        TypeExpr::Name(_, name) => {
2493            let name = normalize_type_name(&name.to_dotted_symbol());
2494            if let Some(arity) = annotation_type_arity(adts, &name) {
2495                Ok(Type::con(name, arity))
2496            } else if let Some(tv) = vars.get(&name) {
2497                Ok(Type::var(tv.clone()))
2498            } else {
2499                let is_upper = name
2500                    .chars()
2501                    .next()
2502                    .map(|c| c.is_uppercase())
2503                    .unwrap_or(false);
2504                if is_upper {
2505                    return Err(TypeError::UnknownTypeName(name));
2506                }
2507                let tv = supply.fresh(Some(name.clone()));
2508                vars.insert(name.clone(), tv.clone());
2509                Ok(Type::var(tv))
2510            }
2511        }
2512        TypeExpr::App(_, fun, arg) => {
2513            let fty = type_from_annotation_expr_vars(adts, fun, vars, supply)?;
2514            let aty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2515            Ok(type_app_with_result_syntax(fty, aty))
2516        }
2517        TypeExpr::Fun(_, arg, ret) => {
2518            let arg_ty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2519            let ret_ty = type_from_annotation_expr_vars(adts, ret, vars, supply)?;
2520            Ok(Type::fun(arg_ty, ret_ty))
2521        }
2522        TypeExpr::Tuple(_, elems) => {
2523            let mut out = Vec::new();
2524            for elem in elems {
2525                out.push(type_from_annotation_expr_vars(adts, elem, vars, supply)?);
2526            }
2527            Ok(Type::tuple(out))
2528        }
2529        TypeExpr::Record(_, fields) => {
2530            let mut out = Vec::new();
2531            for (name, ty) in fields {
2532                out.push((
2533                    name.clone(),
2534                    type_from_annotation_expr_vars(adts, ty, vars, supply)?,
2535                ));
2536            }
2537            Ok(Type::record(out))
2538        }
2539    })();
2540    res.map_err(|err| with_span(&span, err))
2541}
2542
2543fn annotation_type_arity(adts: &HashMap<Symbol, AdtDecl>, name: &Symbol) -> Option<usize> {
2544    if let Some(adt) = adts.get(name) {
2545        return Some(adt.params.len());
2546    }
2547    BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2548}
2549
2550fn normalize_type_name(name: &Symbol) -> Symbol {
2551    if name.as_ref() == "str" {
2552        BuiltinTypeId::String.as_symbol()
2553    } else {
2554        name.clone()
2555    }
2556}
2557
2558fn type_app_with_result_syntax(fun: Type, arg: Type) -> Type {
2559    if let TypeKind::App(head, ok) = fun.as_ref()
2560        && matches!(
2561            head.as_ref(),
2562            TypeKind::Con(c)
2563                if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
2564        )
2565    {
2566        return Type::app(Type::app(head.clone(), arg), ok.clone());
2567    }
2568    Type::app(fun, arg)
2569}
2570
2571fn predicates_from_constraints(
2572    adts: &HashMap<Symbol, AdtDecl>,
2573    constraints: &[TypeConstraint],
2574    vars: &mut HashMap<Symbol, TypeVar>,
2575    supply: &mut TypeVarSupply,
2576) -> Result<Vec<Predicate>, TypeError> {
2577    let mut out = Vec::with_capacity(constraints.len());
2578    for constraint in constraints {
2579        let ty = type_from_annotation_expr_vars(adts, &constraint.typ, vars, supply)?;
2580        out.push(Predicate::new(constraint.class.as_ref(), ty));
2581    }
2582    Ok(out)
2583}
2584
/// A user-defined type constructor name that was observed with more than one
/// distinct definition (for example, with two different arities).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AdtConflict {
    /// The constructor name shared by all of `definitions`.
    pub name: Symbol,
    /// Every distinct constructor head seen under `name`. When produced by
    /// `collect_adts_in_types`, these are listed in first-seen order.
    pub definitions: Vec<Type>,
}
2590
/// Error returned by `collect_adts_in_types` when one or more constructor
/// names are ambiguous.
///
/// The `Display` message embeds the `Debug` rendering of `conflicts`.
#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)]
#[error("conflicting ADT definitions: {conflicts:?}")]
pub struct CollectAdtsError {
    /// One entry per ambiguous constructor name. When produced by
    /// `collect_adts_in_types`, entries are ordered by name.
    pub conflicts: Vec<AdtConflict>,
}
2596
2597/// Collect all user-defined ADT constructors referenced by the provided types.
2598///
2599/// This walks each type recursively (including nested occurrences), returns a
2600/// deduplicated list of constructor heads, and rejects ambiguous constructor
2601/// names that appear with incompatible definitions.
2602///
2603/// The returned `Type`s are constructor heads (for example `Foo`), suitable
2604/// for passing to embedder utilities that derive `AdtDecl`s from type
2605/// constructors.
2606///
2607/// # Examples
2608///
2609/// ```rust,ignore
2610/// use rex_ts::{collect_adts_in_types, BuiltinTypeId, Type};
2611///
2612/// let types = vec![
2613///     Type::app(Type::user_con("Foo", 1), Type::builtin(BuiltinTypeId::I32)),
2614///     Type::fun(Type::user_con("Bar", 0), Type::user_con("Foo", 1)),
2615/// ];
2616///
2617/// let adts = collect_adts_in_types(types).unwrap();
2618/// assert_eq!(adts, vec![Type::user_con("Foo", 1), Type::user_con("Bar", 0)]);
2619/// ```
2620///
2621/// ```rust,ignore
2622/// use rex_ts::{collect_adts_in_types, Type};
2623///
2624/// let err = collect_adts_in_types(vec![
2625///     Type::user_con("Thing", 1),
2626///     Type::user_con("Thing", 2),
2627/// ])
2628/// .unwrap_err();
2629///
2630/// assert_eq!(err.conflicts.len(), 1);
2631/// assert_eq!(err.conflicts[0].name.as_ref(), "Thing");
2632/// ```
2633pub fn collect_adts_in_types(types: Vec<Type>) -> Result<Vec<Type>, CollectAdtsError> {
2634    fn visit(
2635        typ: &Type,
2636        out: &mut Vec<Type>,
2637        seen: &mut HashSet<Type>,
2638        defs_by_name: &mut BTreeMap<Symbol, Vec<Type>>,
2639    ) {
2640        match typ.as_ref() {
2641            TypeKind::Var(_) => {}
2642            TypeKind::Con(tc) => {
2643                // Builtins are not embeddable ADT declarations.
2644                if tc.builtin_id.is_none() {
2645                    let adt = Type::new(TypeKind::Con(tc.clone()));
2646                    if seen.insert(adt.clone()) {
2647                        out.push(adt.clone());
2648                    }
2649                    let defs = defs_by_name.entry(tc.name.clone()).or_default();
2650                    if !defs.contains(&adt) {
2651                        defs.push(adt);
2652                    }
2653                }
2654            }
2655            TypeKind::App(fun, arg) => {
2656                visit(fun, out, seen, defs_by_name);
2657                visit(arg, out, seen, defs_by_name);
2658            }
2659            TypeKind::Fun(arg, ret) => {
2660                visit(arg, out, seen, defs_by_name);
2661                visit(ret, out, seen, defs_by_name);
2662            }
2663            TypeKind::Tuple(elems) => {
2664                for elem in elems {
2665                    visit(elem, out, seen, defs_by_name);
2666                }
2667            }
2668            TypeKind::Record(fields) => {
2669                for (_name, field_ty) in fields {
2670                    visit(field_ty, out, seen, defs_by_name);
2671                }
2672            }
2673        }
2674    }
2675
2676    let mut out = Vec::new();
2677    let mut seen = HashSet::new();
2678    let mut defs_by_name: BTreeMap<Symbol, Vec<Type>> = BTreeMap::new();
2679    for typ in &types {
2680        visit(typ, &mut out, &mut seen, &mut defs_by_name);
2681    }
2682
2683    let conflicts: Vec<AdtConflict> = defs_by_name
2684        .into_iter()
2685        .filter_map(|(name, definitions)| {
2686            (definitions.len() > 1).then_some(AdtConflict { name, definitions })
2687        })
2688        .collect();
2689    if !conflicts.is_empty() {
2690        return Err(CollectAdtsError { conflicts });
2691    }
2692
2693    Ok(out)
2694}