Skip to main content

rex_typesystem/
types.rs

1use crate::{
2    error::{AdtConflict, CollectAdtsError},
3    typesystem::TypeVarSupply,
4    unification::{Subst, subst_is_empty},
5};
6use chrono::{DateTime, Utc};
7use rex_ast::expr::{Pattern, Symbol, intern, sym};
8use rpds::HashTrieMapSync;
9use std::{
10    collections::{BTreeMap, BTreeSet},
11    fmt::{self, Display, Formatter},
12    sync::Arc,
13};
14use uuid::Uuid;
15
/// Numeric identity of a type variable.
///
/// Substitutions are looked up by this id, and free-variable queries
/// (`Types::ftv`) report sets of these ids.
pub type TypeVarId = usize;
17
/// Type constructors that are built into the type system.
///
/// The scalar variants (`U8` through `DateTime`) take no type arguments; the
/// remaining variants are constructors applied to one or two argument types.
///
/// NOTE: variant order is significant — the derived `Ord`/`PartialOrd` follow it.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum BuiltinTypeId {
    // Unsigned integers.
    U8,
    U16,
    U32,
    U64,
    // Signed integers.
    I8,
    I16,
    I32,
    I64,
    // Floating point.
    F32,
    F64,
    // Other scalars.
    Bool,
    String,
    Uuid,
    DateTime,
    // Parameterised constructors.
    List,
    Array,
    Dict,
    Option,
    Promise,
    Result,
}
41
42impl BuiltinTypeId {
43    pub fn as_symbol(self) -> Symbol {
44        sym(self.as_str())
45    }
46
47    pub fn as_str(self) -> &'static str {
48        match self {
49            Self::U8 => "u8",
50            Self::U16 => "u16",
51            Self::U32 => "u32",
52            Self::U64 => "u64",
53            Self::I8 => "i8",
54            Self::I16 => "i16",
55            Self::I32 => "i32",
56            Self::I64 => "i64",
57            Self::F32 => "f32",
58            Self::F64 => "f64",
59            Self::Bool => "bool",
60            Self::String => "string",
61            Self::Uuid => "uuid",
62            Self::DateTime => "datetime",
63            Self::List => "List",
64            Self::Array => "Array",
65            Self::Dict => "Dict",
66            Self::Option => "Option",
67            Self::Promise => "Promise",
68            Self::Result => "Result",
69        }
70    }
71
72    pub fn arity(self) -> usize {
73        match self {
74            Self::List | Self::Array | Self::Dict | Self::Option | Self::Promise => 1,
75            Self::Result => 2,
76            _ => 0,
77        }
78    }
79
80    pub fn from_symbol(name: &Symbol) -> Option<Self> {
81        Self::from_name(name.as_ref())
82    }
83
84    pub fn from_name(name: &str) -> Option<Self> {
85        match name {
86            "u8" => Some(Self::U8),
87            "u16" => Some(Self::U16),
88            "u32" => Some(Self::U32),
89            "u64" => Some(Self::U64),
90            "i8" => Some(Self::I8),
91            "i16" => Some(Self::I16),
92            "i32" => Some(Self::I32),
93            "i64" => Some(Self::I64),
94            "f32" => Some(Self::F32),
95            "f64" => Some(Self::F64),
96            "bool" => Some(Self::Bool),
97            "string" => Some(Self::String),
98            "uuid" => Some(Self::Uuid),
99            "datetime" => Some(Self::DateTime),
100            "List" => Some(Self::List),
101            "Array" => Some(Self::Array),
102            "Dict" => Some(Self::Dict),
103            "Option" => Some(Self::Option),
104            "Promise" => Some(Self::Promise),
105            "Result" => Some(Self::Result),
106            _ => None,
107        }
108    }
109}
110
111#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
112pub struct TypeVar {
113    pub id: TypeVarId,
114    pub name: Option<Symbol>,
115}
116
117impl TypeVar {
118    pub fn new(id: TypeVarId, name: impl Into<Option<Symbol>>) -> Self {
119        Self {
120            id,
121            name: name.into(),
122        }
123    }
124}
125
126#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
127pub struct TypeConst {
128    pub name: Symbol,
129    pub arity: usize,
130    pub builtin_id: Option<BuiltinTypeId>,
131}
132
133#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
134pub struct Type(Arc<TypeKind>);
135
136#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
137pub enum TypeKind {
138    Var(TypeVar),
139    Con(TypeConst),
140    App(Type, Type),
141    Fun(Type, Type),
142    Tuple(Vec<Type>),
143    /// Record type `{a: T, b: U}`.
144    ///
145    /// Invariant: fields are sorted by name. This makes record equality and
146    /// unification a cheap zip over two vectors, and it makes printing stable.
147    Record(Vec<(Symbol, Type)>),
148}
149
150impl Type {
151    pub fn new(kind: TypeKind) -> Self {
152        Type(Arc::new(kind))
153    }
154
155    pub fn con(name: impl AsRef<str>, arity: usize) -> Self {
156        if let Some(id) = BuiltinTypeId::from_name(name.as_ref())
157            && id.arity() == arity
158        {
159            return Self::builtin(id);
160        }
161        Self::user_con(name, arity)
162    }
163
164    pub fn user_con(name: impl AsRef<str>, arity: usize) -> Self {
165        Type::new(TypeKind::Con(TypeConst {
166            name: intern(name.as_ref()),
167            arity,
168            builtin_id: None,
169        }))
170    }
171
172    pub fn builtin(id: BuiltinTypeId) -> Self {
173        Type::new(TypeKind::Con(TypeConst {
174            name: id.as_symbol(),
175            arity: id.arity(),
176            builtin_id: Some(id),
177        }))
178    }
179
180    pub fn var(tv: TypeVar) -> Self {
181        Type::new(TypeKind::Var(tv))
182    }
183
184    pub fn fun(a: Type, b: Type) -> Self {
185        Type::new(TypeKind::Fun(a, b))
186    }
187
188    pub fn app(f: Type, arg: Type) -> Self {
189        Type::new(TypeKind::App(f, arg))
190    }
191
192    pub fn tuple(elems: Vec<Type>) -> Self {
193        Type::new(TypeKind::Tuple(elems))
194    }
195
196    pub fn record(mut fields: Vec<(Symbol, Type)>) -> Self {
197        // Canonicalize records so downstream code can rely on “same shape means
198        // same ordering”. (This is a correctness invariant, not a nicety.)
199        fields.sort_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
200        Type::new(TypeKind::Record(fields))
201    }
202
203    pub fn list(elem: Type) -> Type {
204        Type::app(Type::builtin(BuiltinTypeId::List), elem)
205    }
206
207    pub fn array(elem: Type) -> Type {
208        Type::app(Type::builtin(BuiltinTypeId::Array), elem)
209    }
210
211    pub fn dict(elem: Type) -> Type {
212        Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
213    }
214
215    pub fn option(elem: Type) -> Type {
216        Type::app(Type::builtin(BuiltinTypeId::Option), elem)
217    }
218
219    pub fn promise(elem: Type) -> Type {
220        Type::app(Type::builtin(BuiltinTypeId::Promise), elem)
221    }
222
223    pub fn result(ok: Type, err: Type) -> Type {
224        Type::app(Type::app(Type::builtin(BuiltinTypeId::Result), err), ok)
225    }
226
227    fn apply_with_change(&self, s: &Subst) -> (Type, bool) {
228        match self.as_ref() {
229            TypeKind::Var(tv) => match s.get(&tv.id) {
230                Some(ty) => (ty.clone(), true),
231                None => (self.clone(), false),
232            },
233            TypeKind::Con(_) => (self.clone(), false),
234            TypeKind::App(l, r) => {
235                let (l_new, l_changed) = l.apply_with_change(s);
236                let (r_new, r_changed) = r.apply_with_change(s);
237                if l_changed || r_changed {
238                    (Type::app(l_new, r_new), true)
239                } else {
240                    (self.clone(), false)
241                }
242            }
243            TypeKind::Fun(_, _) => {
244                // Avoid recursive descent on long function chains like
245                // `a1 -> a2 -> ... -> an -> r`.
246                let mut args = Vec::new();
247                let mut changed = false;
248                let mut cur: &Type = self;
249                while let TypeKind::Fun(a, b) = cur.as_ref() {
250                    let (a_new, a_changed) = a.apply_with_change(s);
251                    changed |= a_changed;
252                    args.push(a_new);
253                    cur = b;
254                }
255                let (ret_new, ret_changed) = cur.apply_with_change(s);
256                changed |= ret_changed;
257                if !changed {
258                    return (self.clone(), false);
259                }
260                let mut out = ret_new;
261                for a_new in args.into_iter().rev() {
262                    out = Type::fun(a_new, out);
263                }
264                (out, true)
265            }
266            TypeKind::Tuple(ts) => {
267                let mut changed = false;
268                let mut out = Vec::with_capacity(ts.len());
269                for t in ts {
270                    let (t_new, t_changed) = t.apply_with_change(s);
271                    changed |= t_changed;
272                    out.push(t_new);
273                }
274                if changed {
275                    (Type::new(TypeKind::Tuple(out)), true)
276                } else {
277                    (self.clone(), false)
278                }
279            }
280            TypeKind::Record(fields) => {
281                let mut changed = false;
282                let mut out = Vec::with_capacity(fields.len());
283                for (k, v) in fields {
284                    let (v_new, v_changed) = v.apply_with_change(s);
285                    changed |= v_changed;
286                    out.push((k.clone(), v_new));
287                }
288                if changed {
289                    (Type::new(TypeKind::Record(out)), true)
290                } else {
291                    (self.clone(), false)
292                }
293            }
294        }
295    }
296
297    pub fn for_each<F>(&self, mut f: F) -> Type
298    where
299        F: FnMut(&Type),
300    {
301        self.transform(|t| {
302            f(t);
303            None
304        })
305    }
306
307    pub fn transform<F>(&self, mut f: F) -> Type
308    where
309        F: FnMut(&Type) -> Option<Type>,
310    {
311        self.transform_ref(&mut f)
312    }
313
314    fn transform_ref<F>(&self, f: &mut F) -> Type
315    where
316        F: FnMut(&Type) -> Option<Type>,
317    {
318        if let Some(repl) = f(self) {
319            return repl;
320        }
321
322        match self.as_ref() {
323            TypeKind::Var(type_var) => Type(Arc::new(TypeKind::Var(type_var.clone()))),
324            TypeKind::Con(type_const) => Type(Arc::new(TypeKind::Con(type_const.clone()))),
325            TypeKind::App(fun, arg) => Type(Arc::new(TypeKind::App(
326                fun.transform_ref(f),
327                arg.transform_ref(f),
328            ))),
329            TypeKind::Fun(arg, res) => Type(Arc::new(TypeKind::Fun(
330                arg.transform_ref(f),
331                res.transform_ref(f),
332            ))),
333            TypeKind::Tuple(ts) => Type(Arc::new(TypeKind::Tuple(
334                ts.iter().map(|t| t.transform_ref(f)).collect(),
335            ))),
336            TypeKind::Record(fields) => Type(Arc::new(TypeKind::Record(
337                fields
338                    .iter()
339                    .map(|(s, t)| (s.clone(), t.transform_ref(f)))
340                    .collect(),
341            ))),
342        }
343    }
344}
345
346impl AsRef<TypeKind> for Type {
347    fn as_ref(&self) -> &TypeKind {
348        self.0.as_ref()
349    }
350}
351
352impl std::ops::Deref for Type {
353    type Target = TypeKind;
354
355    fn deref(&self) -> &Self::Target {
356        &self.0
357    }
358}
359
360impl Display for Type {
361    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
362        match self.as_ref() {
363            TypeKind::Var(tv) => match &tv.name {
364                Some(name) => write!(f, "'{}", name),
365                None => write!(f, "t{}", tv.id),
366            },
367            TypeKind::Con(c) => write!(f, "{}", c.name),
368            TypeKind::App(l, r) => {
369                // Internally `Result` is represented as `Result err ok` so it can be partially
370                // applied as `Result err` for HKTs (Functor/Monad/etc).
371                //
372                // User-facing syntax is `Result ok err` (Rust-style), so render the fully
373                // applied form with swapped arguments.
374                if let TypeKind::App(head, err) = l.as_ref()
375                    && matches!(
376                        head.as_ref(),
377                        TypeKind::Con(c)
378                            if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
379                    )
380                {
381                    return write!(f, "(Result {} {})", r, err);
382                }
383                write!(f, "({} {})", l, r)
384            }
385            TypeKind::Fun(a, b) => write!(f, "({} -> {})", a, b),
386            TypeKind::Tuple(elems) => {
387                write!(f, "(")?;
388                for (i, t) in elems.iter().enumerate() {
389                    write!(f, "{}", t)?;
390                    if i + 1 < elems.len() {
391                        write!(f, ", ")?;
392                    }
393                }
394                write!(f, ")")
395            }
396            TypeKind::Record(fields) => {
397                write!(f, "{{")?;
398                for (i, (name, ty)) in fields.iter().enumerate() {
399                    write!(f, "{}: {}", name, ty)?;
400                    if i + 1 < fields.len() {
401                        write!(f, ", ")?;
402                    }
403                }
404                write!(f, "}}")
405            }
406        }
407    }
408}
409
410#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
411pub struct Predicate {
412    pub class: Symbol,
413    pub typ: Type,
414}
415
416impl Predicate {
417    pub fn new(class: impl AsRef<str>, typ: Type) -> Self {
418        Self {
419            class: intern(class.as_ref()),
420            typ,
421        }
422    }
423}
424
425#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
426pub struct Scheme {
427    pub vars: Vec<TypeVar>,
428    pub preds: Vec<Predicate>,
429    pub typ: Type,
430}
431
432impl Scheme {
433    pub fn new(vars: Vec<TypeVar>, preds: Vec<Predicate>, typ: Type) -> Self {
434        Self { vars, preds, typ }
435    }
436}
437
438pub trait Types: Sized {
439    fn apply(&self, s: &Subst) -> Self;
440    fn ftv(&self) -> BTreeSet<TypeVarId>;
441}
442
443impl Types for Type {
444    fn apply(&self, s: &Subst) -> Self {
445        self.apply_with_change(s).0
446    }
447
448    fn ftv(&self) -> BTreeSet<TypeVarId> {
449        let mut out = BTreeSet::new();
450        let mut stack: Vec<&Type> = vec![self];
451        while let Some(t) = stack.pop() {
452            match t.as_ref() {
453                TypeKind::Var(tv) => {
454                    out.insert(tv.id);
455                }
456                TypeKind::Con(_) => {}
457                TypeKind::App(l, r) => {
458                    stack.push(l);
459                    stack.push(r);
460                }
461                TypeKind::Fun(a, b) => {
462                    stack.push(a);
463                    stack.push(b);
464                }
465                TypeKind::Tuple(ts) => {
466                    for t in ts {
467                        stack.push(t);
468                    }
469                }
470                TypeKind::Record(fields) => {
471                    for (_, ty) in fields {
472                        stack.push(ty);
473                    }
474                }
475            }
476        }
477        out
478    }
479}
480
481impl Types for Predicate {
482    fn apply(&self, s: &Subst) -> Self {
483        Predicate {
484            class: self.class.clone(),
485            typ: self.typ.apply(s),
486        }
487    }
488
489    fn ftv(&self) -> BTreeSet<TypeVarId> {
490        self.typ.ftv()
491    }
492}
493
494impl Types for Scheme {
495    fn apply(&self, s: &Subst) -> Self {
496        let mut s_pruned = Subst::new_sync();
497        for (k, v) in s.iter() {
498            if !self.vars.iter().any(|var| var.id == *k) {
499                s_pruned = s_pruned.insert(*k, v.clone());
500            }
501        }
502        Scheme::new(
503            self.vars.clone(),
504            self.preds.iter().map(|p| p.apply(&s_pruned)).collect(),
505            self.typ.apply(&s_pruned),
506        )
507    }
508
509    fn ftv(&self) -> BTreeSet<TypeVarId> {
510        let mut ftv = self.typ.ftv();
511        for p in &self.preds {
512            ftv.extend(p.ftv());
513        }
514        for v in &self.vars {
515            ftv.remove(&v.id);
516        }
517        ftv
518    }
519}
520
521impl<T: Types> Types for Vec<T> {
522    fn apply(&self, s: &Subst) -> Self {
523        self.iter().map(|t| t.apply(s)).collect()
524    }
525
526    fn ftv(&self) -> BTreeSet<TypeVarId> {
527        self.iter().flat_map(Types::ftv).collect()
528    }
529}
530
531#[derive(Clone, Debug, PartialEq)]
532pub struct TypedExpr {
533    pub typ: Type,
534    pub kind: TypedExprKind,
535}
536
537impl TypedExpr {
538    pub fn new(typ: Type, kind: TypedExprKind) -> Self {
539        Self { typ, kind }
540    }
541
542    pub fn apply(&self, s: &Subst) -> Self {
543        match &self.kind {
544            TypedExprKind::Lam { .. } => {
545                let mut params: Vec<(Symbol, Type)> = Vec::new();
546                let mut cur = self;
547                while let TypedExprKind::Lam { param, body } = &cur.kind {
548                    params.push((param.clone(), cur.typ.apply(s)));
549                    cur = body.as_ref();
550                }
551                let mut out = cur.apply(s);
552                for (param, typ) in params.into_iter().rev() {
553                    out = TypedExpr {
554                        typ,
555                        kind: TypedExprKind::Lam {
556                            param,
557                            body: Box::new(out),
558                        },
559                    };
560                }
561                return out;
562            }
563            TypedExprKind::App(..) => {
564                let mut apps: Vec<(Type, &TypedExpr)> = Vec::new();
565                let mut cur = self;
566                while let TypedExprKind::App(f, x) = &cur.kind {
567                    apps.push((cur.typ.apply(s), x.as_ref()));
568                    cur = f.as_ref();
569                }
570                let mut out = cur.apply(s);
571                for (typ, arg) in apps.into_iter().rev() {
572                    out = TypedExpr {
573                        typ,
574                        kind: TypedExprKind::App(Box::new(out), Box::new(arg.apply(s))),
575                    };
576                }
577                return out;
578            }
579            _ => {}
580        }
581
582        let typ = self.typ.apply(s);
583        let kind = match &self.kind {
584            TypedExprKind::Bool(v) => TypedExprKind::Bool(*v),
585            TypedExprKind::Uint(v) => TypedExprKind::Uint(*v),
586            TypedExprKind::Int(v) => TypedExprKind::Int(*v),
587            TypedExprKind::Float(v) => TypedExprKind::Float(*v),
588            TypedExprKind::String(v) => TypedExprKind::String(v.clone()),
589            TypedExprKind::Uuid(v) => TypedExprKind::Uuid(*v),
590            TypedExprKind::DateTime(v) => TypedExprKind::DateTime(*v),
591            TypedExprKind::Hole => TypedExprKind::Hole,
592            TypedExprKind::Tuple(elems) => {
593                TypedExprKind::Tuple(elems.iter().map(|e| e.apply(s)).collect())
594            }
595            TypedExprKind::List(elems) => {
596                TypedExprKind::List(elems.iter().map(|e| e.apply(s)).collect())
597            }
598            TypedExprKind::Dict(kvs) => {
599                let mut out = BTreeMap::new();
600                for (k, v) in kvs {
601                    out.insert(k.clone(), v.apply(s));
602                }
603                TypedExprKind::Dict(out)
604            }
605            TypedExprKind::RecordUpdate { base, updates } => {
606                let mut out = BTreeMap::new();
607                for (k, v) in updates {
608                    out.insert(k.clone(), v.apply(s));
609                }
610                TypedExprKind::RecordUpdate {
611                    base: Box::new(base.apply(s)),
612                    updates: out,
613                }
614            }
615            TypedExprKind::Var { name, overloads } => TypedExprKind::Var {
616                name: name.clone(),
617                overloads: overloads.iter().map(|t| t.apply(s)).collect(),
618            },
619            TypedExprKind::App(f, x) => {
620                TypedExprKind::App(Box::new(f.apply(s)), Box::new(x.apply(s)))
621            }
622            TypedExprKind::Project { expr, field } => TypedExprKind::Project {
623                expr: Box::new(expr.apply(s)),
624                field: field.clone(),
625            },
626            TypedExprKind::Lam { param, body } => TypedExprKind::Lam {
627                param: param.clone(),
628                body: Box::new(body.apply(s)),
629            },
630            TypedExprKind::Let { name, def, body } => TypedExprKind::Let {
631                name: name.clone(),
632                def: Box::new(def.apply(s)),
633                body: Box::new(body.apply(s)),
634            },
635            TypedExprKind::LetRec { bindings, body } => TypedExprKind::LetRec {
636                bindings: bindings
637                    .iter()
638                    .map(|(name, def)| (name.clone(), def.apply(s)))
639                    .collect(),
640                body: Box::new(body.apply(s)),
641            },
642            TypedExprKind::Ite {
643                cond,
644                then_expr,
645                else_expr,
646            } => TypedExprKind::Ite {
647                cond: Box::new(cond.apply(s)),
648                then_expr: Box::new(then_expr.apply(s)),
649                else_expr: Box::new(else_expr.apply(s)),
650            },
651            TypedExprKind::Match { scrutinee, arms } => TypedExprKind::Match {
652                scrutinee: Box::new(scrutinee.apply(s)),
653                arms: arms.iter().map(|(p, e)| (p.clone(), e.apply(s))).collect(),
654            },
655        };
656        TypedExpr { typ, kind }
657    }
658}
659
660#[derive(Clone, Debug, PartialEq)]
661pub enum TypedExprKind {
662    Bool(bool),
663    Uint(u64),
664    Int(i64),
665    Float(f64),
666    String(String),
667    Uuid(Uuid),
668    DateTime(DateTime<Utc>),
669    Hole,
670    Tuple(Vec<TypedExpr>),
671    List(Vec<TypedExpr>),
672    Dict(BTreeMap<Symbol, TypedExpr>),
673    RecordUpdate {
674        base: Box<TypedExpr>,
675        updates: BTreeMap<Symbol, TypedExpr>,
676    },
677    Var {
678        name: Symbol,
679        overloads: Vec<Type>,
680    },
681    App(Box<TypedExpr>, Box<TypedExpr>),
682    Project {
683        expr: Box<TypedExpr>,
684        field: Symbol,
685    },
686    Lam {
687        param: Symbol,
688        body: Box<TypedExpr>,
689    },
690    Let {
691        name: Symbol,
692        def: Box<TypedExpr>,
693        body: Box<TypedExpr>,
694    },
695    LetRec {
696        bindings: Vec<(Symbol, TypedExpr)>,
697        body: Box<TypedExpr>,
698    },
699    Ite {
700        cond: Box<TypedExpr>,
701        then_expr: Box<TypedExpr>,
702        else_expr: Box<TypedExpr>,
703    },
704    Match {
705        scrutinee: Box<TypedExpr>,
706        arms: Vec<(Pattern, TypedExpr)>,
707    },
708}
709
710#[derive(Default, Debug, Clone)]
711pub struct TypeEnv {
712    pub values: HashTrieMapSync<Symbol, Vec<Scheme>>,
713}
714
715impl TypeEnv {
716    pub fn new() -> Self {
717        Self {
718            values: HashTrieMapSync::new_sync(),
719        }
720    }
721
722    pub fn extend(&mut self, name: Symbol, scheme: Scheme) {
723        self.values = self.values.insert(name, vec![scheme]);
724    }
725
726    pub fn extend_overload(&mut self, name: Symbol, scheme: Scheme) {
727        let mut schemes = self.values.get(&name).cloned().unwrap_or_default();
728        schemes.push(scheme);
729        self.values = self.values.insert(name, schemes);
730    }
731
732    pub fn remove(&mut self, name: &Symbol) {
733        self.values = self.values.remove(name);
734    }
735
736    pub fn lookup(&self, name: &Symbol) -> Option<&[Scheme]> {
737        self.values.get(name).map(|schemes| schemes.as_slice())
738    }
739}
740
741impl Types for TypeEnv {
742    fn apply(&self, s: &Subst) -> Self {
743        let mut values = HashTrieMapSync::new_sync();
744        for (k, v) in self.values.iter() {
745            let updated = v
746                .iter()
747                .map(|scheme| {
748                    // Most schemes in environments are monomorphic. Don't walk
749                    // and rebuild trees unless we actually have work to do.
750                    if scheme.vars.is_empty() && !subst_is_empty(s) {
751                        scheme.apply(s)
752                    } else {
753                        scheme.clone()
754                    }
755                })
756                .collect();
757            values = values.insert(k.clone(), updated);
758        }
759        TypeEnv { values }
760    }
761
762    fn ftv(&self) -> BTreeSet<TypeVarId> {
763        self.values
764            .iter()
765            .flat_map(|(_, schemes)| schemes.iter().flat_map(Types::ftv))
766            .collect()
767    }
768}
769
770/// A named type parameter for an ADT (e.g. `a` in `List a`).
771#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
772pub struct AdtParam {
773    pub name: Symbol,
774    pub var: TypeVar,
775}
776
777/// A single ADT variant with zero or more constructor arguments.
778#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
779pub struct AdtVariant {
780    pub name: Symbol,
781    pub args: Vec<Type>,
782}
783
784/// A type declaration for an algebraic data type.
785///
786/// This only describes the *type* surface (params + variants). It does not
787/// introduce any runtime values by itself. Runtime values are created by
788/// injecting constructor schemes into the environment (see `inject_adt`).
789#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
790pub struct AdtDecl {
791    pub name: Symbol,
792    pub params: Vec<AdtParam>,
793    pub variants: Vec<AdtVariant>,
794}
795
796impl AdtDecl {
797    pub fn new(name: &Symbol, param_names: &[Symbol], supply: &mut TypeVarSupply) -> Self {
798        let params = param_names
799            .iter()
800            .map(|p| AdtParam {
801                name: p.clone(),
802                var: supply.fresh(Some(p.clone())),
803            })
804            .collect();
805        Self {
806            name: name.clone(),
807            params,
808            variants: Vec::new(),
809        }
810    }
811
812    pub fn param_type(&self, name: &Symbol) -> Option<Type> {
813        self.params
814            .iter()
815            .find(|p| &p.name == name)
816            .map(|p| Type::var(p.var.clone()))
817    }
818
819    pub fn add_variant(&mut self, name: Symbol, args: Vec<Type>) {
820        self.variants.push(AdtVariant { name, args });
821    }
822
823    pub fn result_type(&self) -> Type {
824        let mut ty = Type::con(&self.name, self.params.len());
825        for param in &self.params {
826            ty = Type::app(ty, Type::var(param.var.clone()));
827        }
828        ty
829    }
830
831    /// Build constructor schemes of the form:
832    /// `C :: a1 -> a2 -> ... -> T params`.
833    pub fn constructor_schemes(&self) -> Vec<(Symbol, Scheme)> {
834        let result_ty = self.result_type();
835        let vars: Vec<TypeVar> = self.params.iter().map(|p| p.var.clone()).collect();
836        let mut out = Vec::new();
837        for variant in &self.variants {
838            let mut typ = result_ty.clone();
839            for arg in variant.args.iter().rev() {
840                typ = Type::fun(arg.clone(), typ);
841            }
842            out.push((variant.name.clone(), Scheme::new(vars.clone(), vec![], typ)));
843        }
844        out
845    }
846}
847
848#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
849pub struct Class {
850    pub supers: Vec<Symbol>,
851}
852
853impl Class {
854    pub fn new(supers: Vec<Symbol>) -> Self {
855        Self { supers }
856    }
857}
858
859#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
860pub struct Instance {
861    pub context: Vec<Predicate>,
862    pub head: Predicate,
863}
864
865impl Instance {
866    pub fn new(context: Vec<Predicate>, head: Predicate) -> Self {
867        Self { context, head }
868    }
869}
870
871#[derive(Default, Debug, Clone)]
872pub struct ClassEnv {
873    pub classes: BTreeMap<Symbol, Class>,
874    pub instances: BTreeMap<Symbol, Vec<Instance>>,
875}
876
877impl ClassEnv {
878    pub fn new() -> Self {
879        Self {
880            classes: BTreeMap::new(),
881            instances: BTreeMap::new(),
882        }
883    }
884
885    pub fn add_class(&mut self, name: Symbol, supers: Vec<Symbol>) {
886        self.classes.insert(name, Class::new(supers));
887    }
888
889    pub fn add_instance(&mut self, class: Symbol, inst: Instance) {
890        self.instances.entry(class).or_default().push(inst);
891    }
892
893    pub fn supers_of(&self, class: &Symbol) -> Vec<Symbol> {
894        self.classes
895            .get(class)
896            .map(|c| c.supers.clone())
897            .unwrap_or_default()
898    }
899}
900
901/// Collect all user-defined ADT constructors referenced by the provided types.
902///
903/// This walks each type recursively (including nested occurrences), returns a
904/// deduplicated list of constructor heads, and rejects ambiguous constructor
905/// names that appear with incompatible definitions.
906///
907/// The returned `Type`s are constructor heads (for example `Foo`), suitable
908/// for passing to embedder utilities that derive `AdtDecl`s from type
909/// constructors.
910///
911/// # Examples
912///
913/// ```rust,ignore
914/// use rex_ts::{collect_adts_in_types, BuiltinTypeId, Type};
915///
916/// let types = vec![
917///     Type::app(Type::user_con("Foo", 1), Type::builtin(BuiltinTypeId::I32)),
918///     Type::fun(Type::user_con("Bar", 0), Type::user_con("Foo", 1)),
919/// ];
920///
921/// let adts = collect_adts_in_types(types).unwrap();
922/// assert_eq!(adts, vec![Type::user_con("Foo", 1), Type::user_con("Bar", 0)]);
923/// ```
924///
925/// ```rust,ignore
926/// use rex_ts::{collect_adts_in_types, Type};
927///
928/// let err = collect_adts_in_types(vec![
929///     Type::user_con("Thing", 1),
930///     Type::user_con("Thing", 2),
931/// ])
932/// .unwrap_err();
933///
934/// assert_eq!(err.conflicts.len(), 1);
935/// assert_eq!(err.conflicts[0].name.as_ref(), "Thing");
936/// ```
937pub fn collect_adts_in_types(types: Vec<Type>) -> Result<Vec<Type>, CollectAdtsError> {
938    let mut out = Vec::new();
939    let mut seen = BTreeSet::new();
940    let mut defs_by_name: BTreeMap<Symbol, Vec<Type>> = BTreeMap::new();
941    for typ in &types {
942        typ.for_each(|t| {
943            if let TypeKind::Con(tc) = t.as_ref() {
944                // Builtins are not embeddable ADT declarations.
945                if tc.builtin_id.is_none() {
946                    let adt = Type::new(TypeKind::Con(tc.clone()));
947                    if seen.insert(adt.clone()) {
948                        out.push(adt.clone());
949                    }
950                    let defs = defs_by_name.entry(tc.name.clone()).or_default();
951                    if !defs.contains(&adt) {
952                        defs.push(adt);
953                    }
954                }
955            }
956        });
957    }
958
959    let conflicts: Vec<AdtConflict> = defs_by_name
960        .into_iter()
961        .filter_map(|(name, definitions)| {
962            (definitions.len() > 1).then_some(AdtConflict { name, definitions })
963        })
964        .collect();
965    if !conflicts.is_empty() {
966        return Err(CollectAdtsError { conflicts });
967    }
968
969    Ok(out)
970}