// rex_typesystem/types.rs

1use crate::{
2    error::{AdtConflict, CollectAdtsError, TypeError},
3    typesystem::TypeVarSupply,
4    unification::{Subst, subst_is_empty},
5};
6use chrono::{DateTime, Utc};
7use rex_ast::{Pattern, Symbol};
8use rpds::HashTrieMapSync;
9use std::{
10    cmp::Ordering,
11    collections::{BTreeMap, BTreeSet},
12    fmt::{self, Display, Formatter},
13    mem,
14    sync::Arc,
15};
16use uuid::Uuid;
17
18pub type TypeVarId = usize;
19
20#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
21pub enum BuiltinTypeId {
22    U8,
23    U16,
24    U32,
25    U64,
26    I8,
27    I16,
28    I32,
29    I64,
30    F32,
31    F64,
32    Bool,
33    String,
34    Uuid,
35    DateTime,
36    List,
37    Array,
38    Dict,
39    Option,
40    Promise,
41    Result,
42}
43
44impl BuiltinTypeId {
45    pub fn as_symbol(self) -> Symbol {
46        Symbol::intern(self.as_str())
47    }
48
49    pub fn as_str(self) -> &'static str {
50        match self {
51            Self::U8 => "u8",
52            Self::U16 => "u16",
53            Self::U32 => "u32",
54            Self::U64 => "u64",
55            Self::I8 => "i8",
56            Self::I16 => "i16",
57            Self::I32 => "i32",
58            Self::I64 => "i64",
59            Self::F32 => "f32",
60            Self::F64 => "f64",
61            Self::Bool => "bool",
62            Self::String => "string",
63            Self::Uuid => "uuid",
64            Self::DateTime => "datetime",
65            Self::List => "List",
66            Self::Array => "Array",
67            Self::Dict => "Dict",
68            Self::Option => "Option",
69            Self::Promise => "Promise",
70            Self::Result => "Result",
71        }
72    }
73
74    pub fn arity(self) -> usize {
75        match self {
76            Self::List | Self::Array | Self::Dict | Self::Option | Self::Promise => 1,
77            Self::Result => 2,
78            _ => 0,
79        }
80    }
81
82    pub fn from_symbol(name: &Symbol) -> Option<Self> {
83        Self::from_name(name.as_ref())
84    }
85
86    pub fn from_name(name: &str) -> Option<Self> {
87        match name {
88            "u8" => Some(Self::U8),
89            "u16" => Some(Self::U16),
90            "u32" => Some(Self::U32),
91            "u64" => Some(Self::U64),
92            "i8" => Some(Self::I8),
93            "i16" => Some(Self::I16),
94            "i32" => Some(Self::I32),
95            "i64" => Some(Self::I64),
96            "f32" => Some(Self::F32),
97            "f64" => Some(Self::F64),
98            "bool" => Some(Self::Bool),
99            "string" => Some(Self::String),
100            "uuid" => Some(Self::Uuid),
101            "datetime" => Some(Self::DateTime),
102            "List" => Some(Self::List),
103            "Array" => Some(Self::Array),
104            "Dict" => Some(Self::Dict),
105            "Option" => Some(Self::Option),
106            "Promise" => Some(Self::Promise),
107            "Result" => Some(Self::Result),
108            _ => None,
109        }
110    }
111}
112
113#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
114pub struct TypeVar {
115    pub id: TypeVarId,
116    pub name: Option<Symbol>,
117}
118
119impl TypeVar {
120    pub fn new(id: TypeVarId, name: impl Into<Option<Symbol>>) -> Self {
121        Self {
122            id,
123            name: name.into(),
124        }
125    }
126}
127
128#[derive(Clone, Debug, Hash, Eq, PartialEq)]
129pub enum TypeConst {
130    Builtin(BuiltinTypeId),
131    User { name: Symbol, arity: usize },
132}
133
134impl TypeConst {
135    pub fn builtin_id(&self) -> Option<BuiltinTypeId> {
136        match self {
137            Self::Builtin(id) => Some(*id),
138            Self::User { .. } => None,
139        }
140    }
141
142    pub fn is_builtin(&self, id: BuiltinTypeId) -> bool {
143        self.builtin_id() == Some(id)
144    }
145
146    pub fn name(&self) -> Symbol {
147        match self {
148            Self::Builtin(id) => id.as_symbol(),
149            Self::User { name, .. } => name.clone(),
150        }
151    }
152
153    pub fn name_str(&self) -> &str {
154        match self {
155            Self::Builtin(id) => id.as_str(),
156            Self::User { name, .. } => name.as_ref(),
157        }
158    }
159
160    pub fn user_name(&self) -> Option<&Symbol> {
161        match self {
162            Self::Builtin(_) => None,
163            Self::User { name, .. } => Some(name),
164        }
165    }
166
167    pub fn arity(&self) -> usize {
168        match self {
169            Self::Builtin(id) => id.arity(),
170            Self::User { arity, .. } => *arity,
171        }
172    }
173}
174
175impl Ord for TypeConst {
176    fn cmp(&self, other: &Self) -> Ordering {
177        self.name_str()
178            .cmp(other.name_str())
179            .then_with(|| self.arity().cmp(&other.arity()))
180            .then_with(|| self.builtin_id().cmp(&other.builtin_id()))
181    }
182}
183
184impl PartialOrd for TypeConst {
185    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
186        Some(self.cmp(other))
187    }
188}
189
190#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
191pub struct Type(Arc<TypeKind>);
192
193#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
194pub enum TypeKind {
195    Var(TypeVar),
196    Con(TypeConst),
197    App(Type, Type),
198    Fun(Type, Type),
199    Tuple(Vec<Type>),
200    /// Record type `{a: T, b: U}`.
201    ///
202    /// Invariant: fields are sorted by name. This makes record equality and
203    /// unification a cheap zip over two vectors, and it makes printing stable.
204    Record(Vec<(Symbol, Type)>),
205}
206
207impl Type {
208    pub fn new(kind: TypeKind) -> Self {
209        Type(Arc::new(kind))
210    }
211
212    pub fn con(name: impl AsRef<str>, arity: usize) -> Self {
213        if let Some(id) = BuiltinTypeId::from_name(name.as_ref())
214            && id.arity() == arity
215        {
216            return Self::builtin(id);
217        }
218        Self::user_con(name, arity)
219    }
220
221    pub fn user_con(name: impl AsRef<str>, arity: usize) -> Self {
222        Type::new(TypeKind::Con(TypeConst::User {
223            name: Symbol::intern(name.as_ref()),
224            arity,
225        }))
226    }
227
228    pub fn builtin(id: BuiltinTypeId) -> Self {
229        Type::new(TypeKind::Con(TypeConst::Builtin(id)))
230    }
231
232    pub fn var(tv: TypeVar) -> Self {
233        Type::new(TypeKind::Var(tv))
234    }
235
236    pub fn fun(a: Type, b: Type) -> Self {
237        Type::new(TypeKind::Fun(a, b))
238    }
239
240    pub fn app(f: Type, arg: Type) -> Self {
241        Type::new(TypeKind::App(f, arg))
242    }
243
244    pub fn tuple(elems: Vec<Type>) -> Self {
245        Type::new(TypeKind::Tuple(elems))
246    }
247
248    pub fn record(mut fields: Vec<(Symbol, Type)>) -> Self {
249        // Canonicalize records so downstream code can rely on “same shape means
250        // same ordering”. (This is a correctness invariant, not a nicety.)
251        fields.sort_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
252        Type::new(TypeKind::Record(fields))
253    }
254
255    pub fn list(elem: Type) -> Type {
256        Type::app(Type::builtin(BuiltinTypeId::List), elem)
257    }
258
259    pub fn array(elem: Type) -> Type {
260        Type::app(Type::builtin(BuiltinTypeId::Array), elem)
261    }
262
263    pub fn dict(elem: Type) -> Type {
264        Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
265    }
266
267    pub fn option(elem: Type) -> Type {
268        Type::app(Type::builtin(BuiltinTypeId::Option), elem)
269    }
270
271    pub fn promise(elem: Type) -> Type {
272        Type::app(Type::builtin(BuiltinTypeId::Promise), elem)
273    }
274
275    pub fn result(ok: Type, err: Type) -> Type {
276        Type::app(Type::app(Type::builtin(BuiltinTypeId::Result), err), ok)
277    }
278
279    fn apply_with_change(&self, s: &Subst) -> (Type, bool) {
280        match self.as_ref() {
281            TypeKind::Var(tv) => match s.get(&tv.id) {
282                Some(ty) => (ty.clone(), true),
283                None => (self.clone(), false),
284            },
285            TypeKind::Con(_) => (self.clone(), false),
286            TypeKind::App(l, r) => {
287                let (l_new, l_changed) = l.apply_with_change(s);
288                let (r_new, r_changed) = r.apply_with_change(s);
289                if l_changed || r_changed {
290                    (Type::app(l_new, r_new), true)
291                } else {
292                    (self.clone(), false)
293                }
294            }
295            TypeKind::Fun(_, _) => {
296                // Avoid recursive descent on long function chains like
297                // `a1 -> a2 -> ... -> an -> r`.
298                let mut args = Vec::new();
299                let mut changed = false;
300                let mut cur: &Type = self;
301                while let TypeKind::Fun(a, b) = cur.as_ref() {
302                    let (a_new, a_changed) = a.apply_with_change(s);
303                    changed |= a_changed;
304                    args.push(a_new);
305                    cur = b;
306                }
307                let (ret_new, ret_changed) = cur.apply_with_change(s);
308                changed |= ret_changed;
309                if !changed {
310                    return (self.clone(), false);
311                }
312                let mut out = ret_new;
313                for a_new in args.into_iter().rev() {
314                    out = Type::fun(a_new, out);
315                }
316                (out, true)
317            }
318            TypeKind::Tuple(ts) => {
319                let mut changed = false;
320                let mut out = Vec::with_capacity(ts.len());
321                for t in ts {
322                    let (t_new, t_changed) = t.apply_with_change(s);
323                    changed |= t_changed;
324                    out.push(t_new);
325                }
326                if changed {
327                    (Type::new(TypeKind::Tuple(out)), true)
328                } else {
329                    (self.clone(), false)
330                }
331            }
332            TypeKind::Record(fields) => {
333                let mut changed = false;
334                let mut out = Vec::with_capacity(fields.len());
335                for (k, v) in fields {
336                    let (v_new, v_changed) = v.apply_with_change(s);
337                    changed |= v_changed;
338                    out.push((k.clone(), v_new));
339                }
340                if changed {
341                    (Type::new(TypeKind::Record(out)), true)
342                } else {
343                    (self.clone(), false)
344                }
345            }
346        }
347    }
348
349    pub fn for_each<F>(&self, mut f: F) -> Type
350    where
351        F: FnMut(&Type),
352    {
353        self.transform(|t| {
354            f(t);
355            None
356        })
357    }
358
359    pub fn transform<F>(&self, mut f: F) -> Type
360    where
361        F: FnMut(&Type) -> Option<Type>,
362    {
363        self.transform_ref(&mut f)
364    }
365
366    fn transform_ref<F>(&self, f: &mut F) -> Type
367    where
368        F: FnMut(&Type) -> Option<Type>,
369    {
370        if let Some(repl) = f(self) {
371            return repl;
372        }
373
374        match self.as_ref() {
375            TypeKind::Var(type_var) => Type(Arc::new(TypeKind::Var(type_var.clone()))),
376            TypeKind::Con(type_const) => Type(Arc::new(TypeKind::Con(type_const.clone()))),
377            TypeKind::App(fun, arg) => Type(Arc::new(TypeKind::App(
378                fun.transform_ref(f),
379                arg.transform_ref(f),
380            ))),
381            TypeKind::Fun(arg, res) => Type(Arc::new(TypeKind::Fun(
382                arg.transform_ref(f),
383                res.transform_ref(f),
384            ))),
385            TypeKind::Tuple(ts) => Type(Arc::new(TypeKind::Tuple(
386                ts.iter().map(|t| t.transform_ref(f)).collect(),
387            ))),
388            TypeKind::Record(fields) => Type(Arc::new(TypeKind::Record(
389                fields
390                    .iter()
391                    .map(|(s, t)| (s.clone(), t.transform_ref(f)))
392                    .collect(),
393            ))),
394        }
395    }
396}
397
398impl AsRef<TypeKind> for Type {
399    fn as_ref(&self) -> &TypeKind {
400        self.0.as_ref()
401    }
402}
403
404impl std::ops::Deref for Type {
405    type Target = TypeKind;
406
407    fn deref(&self) -> &Self::Target {
408        &self.0
409    }
410}
411
412impl Display for Type {
413    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
414        match self.as_ref() {
415            TypeKind::Var(tv) => match &tv.name {
416                Some(name) => write!(f, "'{}", name),
417                None => write!(f, "t{}", tv.id),
418            },
419            TypeKind::Con(c) => write!(f, "{}", c.name_str()),
420            TypeKind::App(l, r) => {
421                // Internally `Result` is represented as `Result err ok` so it can be partially
422                // applied as `Result err` for HKTs (Functor/Monad/etc).
423                //
424                // User-facing syntax is `Result ok err` (Rust-style), so render the fully
425                // applied form with swapped arguments.
426                if let TypeKind::App(head, err) = l.as_ref()
427                    && matches!(
428                        head.as_ref(),
429                        TypeKind::Con(c)
430                            if c.is_builtin(BuiltinTypeId::Result) && c.arity() == 2
431                    )
432                {
433                    return write!(f, "(Result {} {})", r, err);
434                }
435                write!(f, "({} {})", l, r)
436            }
437            TypeKind::Fun(a, b) => write!(f, "({} -> {})", a, b),
438            TypeKind::Tuple(elems) => {
439                write!(f, "(")?;
440                for (i, t) in elems.iter().enumerate() {
441                    write!(f, "{}", t)?;
442                    if i + 1 < elems.len() {
443                        write!(f, ", ")?;
444                    }
445                }
446                write!(f, ")")
447            }
448            TypeKind::Record(fields) => {
449                write!(f, "{{")?;
450                for (i, (name, ty)) in fields.iter().enumerate() {
451                    write!(f, "{}: {}", name, ty)?;
452                    if i + 1 < fields.len() {
453                        write!(f, ", ")?;
454                    }
455                }
456                write!(f, "}}")
457            }
458        }
459    }
460}
461
462#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
463pub struct Predicate {
464    pub class: Symbol,
465    pub typ: Type,
466}
467
468impl Predicate {
469    pub fn new(class: impl AsRef<str>, typ: Type) -> Self {
470        Self {
471            class: Symbol::intern(class.as_ref()),
472            typ,
473        }
474    }
475}
476
477#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
478pub struct Scheme {
479    pub vars: Vec<TypeVar>,
480    pub preds: Vec<Predicate>,
481    pub typ: Type,
482}
483
484impl Scheme {
485    pub fn new(vars: Vec<TypeVar>, preds: Vec<Predicate>, typ: Type) -> Self {
486        Self { vars, preds, typ }
487    }
488}
489
490pub trait Types: Sized {
491    fn apply(&self, s: &Subst) -> Self;
492    fn ftv(&self) -> BTreeSet<TypeVarId>;
493}
494
495impl Types for Type {
496    fn apply(&self, s: &Subst) -> Self {
497        self.apply_with_change(s).0
498    }
499
500    fn ftv(&self) -> BTreeSet<TypeVarId> {
501        let mut out = BTreeSet::new();
502        let mut stack: Vec<&Type> = vec![self];
503        while let Some(t) = stack.pop() {
504            match t.as_ref() {
505                TypeKind::Var(tv) => {
506                    out.insert(tv.id);
507                }
508                TypeKind::Con(_) => {}
509                TypeKind::App(l, r) => {
510                    stack.push(l);
511                    stack.push(r);
512                }
513                TypeKind::Fun(a, b) => {
514                    stack.push(a);
515                    stack.push(b);
516                }
517                TypeKind::Tuple(ts) => {
518                    for t in ts {
519                        stack.push(t);
520                    }
521                }
522                TypeKind::Record(fields) => {
523                    for (_, ty) in fields {
524                        stack.push(ty);
525                    }
526                }
527            }
528        }
529        out
530    }
531}
532
533impl Types for Predicate {
534    fn apply(&self, s: &Subst) -> Self {
535        Predicate {
536            class: self.class.clone(),
537            typ: self.typ.apply(s),
538        }
539    }
540
541    fn ftv(&self) -> BTreeSet<TypeVarId> {
542        self.typ.ftv()
543    }
544}
545
546impl Types for Scheme {
547    fn apply(&self, s: &Subst) -> Self {
548        let mut s_pruned = Subst::new_sync();
549        for (k, v) in s.iter() {
550            if !self.vars.iter().any(|var| var.id == *k) {
551                s_pruned = s_pruned.insert(*k, v.clone());
552            }
553        }
554        Scheme::new(
555            self.vars.clone(),
556            self.preds.iter().map(|p| p.apply(&s_pruned)).collect(),
557            self.typ.apply(&s_pruned),
558        )
559    }
560
561    fn ftv(&self) -> BTreeSet<TypeVarId> {
562        let mut ftv = self.typ.ftv();
563        for p in &self.preds {
564            ftv.extend(p.ftv());
565        }
566        for v in &self.vars {
567            ftv.remove(&v.id);
568        }
569        ftv
570    }
571}
572
573impl<T: Types> Types for Vec<T> {
574    fn apply(&self, s: &Subst) -> Self {
575        self.iter().map(|t| t.apply(s)).collect()
576    }
577
578    fn ftv(&self) -> BTreeSet<TypeVarId> {
579        self.iter().flat_map(Types::ftv).collect()
580    }
581}
582
583#[derive(Clone, Debug, PartialEq)]
584pub struct TypedExpr {
585    pub typ: Type,
586    pub kind: Arc<TypedExprKind>,
587}
588
589struct TypedTailAppFrame {
590    head: Arc<TypedExpr>,
591    prefix_args: Vec<(Type, Arc<TypedExpr>)>,
592    tail_result_type: Type,
593}
594
595fn collect_typed_app_chain(expr: &TypedExpr) -> (Arc<TypedExpr>, Vec<(Type, Arc<TypedExpr>)>) {
596    let mut args = Vec::new();
597    let mut cur = expr;
598    while let TypedExprKind::App(f, x) = cur.kind.as_ref() {
599        args.push((cur.typ.clone(), Arc::clone(x)));
600        cur = f.as_ref();
601    }
602    args.reverse();
603    (Arc::new(cur.clone()), args)
604}
605
606fn collect_typed_tail_app_chain(
607    expr: &TypedExpr,
608) -> Option<(Arc<TypedExpr>, Vec<TypedTailAppFrame>)> {
609    let mut frames = Vec::new();
610    let mut cur = Arc::new(expr.clone());
611    while matches!(cur.kind.as_ref(), TypedExprKind::App(..)) {
612        let (head, mut args) = collect_typed_app_chain(cur.as_ref());
613        let Some((tail_result_type, tail)) = args.pop() else {
614            break;
615        };
616        if !matches!(tail.kind.as_ref(), TypedExprKind::App(..)) {
617            break;
618        }
619        frames.push(TypedTailAppFrame {
620            head,
621            prefix_args: args,
622            tail_result_type,
623        });
624        cur = tail;
625    }
626    (!frames.is_empty()).then_some((cur, frames))
627}
628
629fn typed_drop_placeholder() -> Arc<TypedExpr> {
630    Arc::new(TypedExpr::new(Type::tuple(vec![]), TypedExprKind::Hole))
631}
632
633fn drain_typed_expr_kind(kind: &mut TypedExprKind, stack: &mut Vec<Arc<TypedExpr>>) {
634    match kind {
635        TypedExprKind::Tuple(elems) | TypedExprKind::List(elems) => {
636            stack.extend(mem::take(elems));
637        }
638        TypedExprKind::Dict(kvs) => {
639            stack.extend(mem::take(kvs).into_values());
640        }
641        TypedExprKind::RecordUpdate { base, updates } => {
642            stack.push(mem::replace(base, typed_drop_placeholder()));
643            stack.extend(mem::take(updates).into_values());
644        }
645        TypedExprKind::App(f, x) => {
646            stack.push(mem::replace(f, typed_drop_placeholder()));
647            stack.push(mem::replace(x, typed_drop_placeholder()));
648        }
649        TypedExprKind::Project { expr, .. } => {
650            stack.push(mem::replace(expr, typed_drop_placeholder()));
651        }
652        TypedExprKind::Lam { body, .. } => {
653            stack.push(mem::replace(body, typed_drop_placeholder()));
654        }
655        TypedExprKind::Let { def, body, .. } => {
656            stack.push(mem::replace(def, typed_drop_placeholder()));
657            stack.push(mem::replace(body, typed_drop_placeholder()));
658        }
659        TypedExprKind::LetRec { bindings, body } => {
660            for (_name, def) in mem::take(bindings) {
661                stack.push(def);
662            }
663            stack.push(mem::replace(body, typed_drop_placeholder()));
664        }
665        TypedExprKind::Ite {
666            cond,
667            then_expr,
668            else_expr,
669        } => {
670            stack.push(mem::replace(cond, typed_drop_placeholder()));
671            stack.push(mem::replace(then_expr, typed_drop_placeholder()));
672            stack.push(mem::replace(else_expr, typed_drop_placeholder()));
673        }
674        TypedExprKind::Match { scrutinee, arms } => {
675            stack.push(mem::replace(scrutinee, typed_drop_placeholder()));
676            for (_pat, arm) in mem::take(arms) {
677                stack.push(arm);
678            }
679        }
680        TypedExprKind::Bool(..)
681        | TypedExprKind::Uint(..)
682        | TypedExprKind::Int(..)
683        | TypedExprKind::Float(..)
684        | TypedExprKind::String(..)
685        | TypedExprKind::Uuid(..)
686        | TypedExprKind::DateTime(..)
687        | TypedExprKind::Hole
688        | TypedExprKind::Var { .. } => {}
689    }
690}
691
692impl Drop for TypedExpr {
693    fn drop(&mut self) {
694        let Some(kind) = Arc::get_mut(&mut self.kind) else {
695            return;
696        };
697        let mut stack = Vec::new();
698        drain_typed_expr_kind(kind, &mut stack);
699        while let Some(mut expr) = stack.pop() {
700            let Some(expr) = Arc::get_mut(&mut expr) else {
701                continue;
702            };
703            let Some(kind) = Arc::get_mut(&mut expr.kind) else {
704                continue;
705            };
706            drain_typed_expr_kind(kind, &mut stack);
707        }
708    }
709}
710
711impl TypedExpr {
712    pub fn new(typ: Type, kind: TypedExprKind) -> Self {
713        Self {
714            typ,
715            kind: Arc::new(kind),
716        }
717    }
718
719    pub fn apply(&self, s: &Subst) -> Self {
720        // TODO: This still allocates a transformed expression tree. That may
721        // become too expensive for hot polymorphic apply paths once evaluator
722        // frames retain shared typed AST nodes.
723        match self.kind.as_ref() {
724            TypedExprKind::Lam { .. } => {
725                let mut params: Vec<(Symbol, Type)> = Vec::new();
726                let mut cur = self;
727                while let TypedExprKind::Lam { param, body } = cur.kind.as_ref() {
728                    params.push((param.clone(), cur.typ.apply(s)));
729                    cur = body.as_ref();
730                }
731                let mut out = cur.apply(s);
732                for (param, typ) in params.into_iter().rev() {
733                    out = TypedExpr::new(
734                        typ,
735                        TypedExprKind::Lam {
736                            param,
737                            body: Arc::new(out),
738                        },
739                    );
740                }
741                return out;
742            }
743            TypedExprKind::App(..) => {
744                if let Some((leaf, frames)) = collect_typed_tail_app_chain(self) {
745                    let mut out = leaf.apply(s);
746                    for frame in frames.into_iter().rev() {
747                        let mut typed = frame.head.apply(s);
748                        for (typ, arg) in frame.prefix_args {
749                            typed = TypedExpr::new(
750                                typ.apply(s),
751                                TypedExprKind::App(Arc::new(typed), Arc::new(arg.apply(s))),
752                            );
753                        }
754                        out = TypedExpr::new(
755                            frame.tail_result_type.apply(s),
756                            TypedExprKind::App(Arc::new(typed), Arc::new(out)),
757                        );
758                    }
759                    return out;
760                }
761
762                let mut apps: Vec<(Type, Arc<TypedExpr>)> = Vec::new();
763                let mut cur = self;
764                while let TypedExprKind::App(f, x) = cur.kind.as_ref() {
765                    apps.push((cur.typ.apply(s), Arc::clone(x)));
766                    cur = f.as_ref();
767                }
768                let mut out = cur.apply(s);
769                for (typ, arg) in apps.into_iter().rev() {
770                    out = TypedExpr::new(
771                        typ,
772                        TypedExprKind::App(Arc::new(out), Arc::new(arg.apply(s))),
773                    );
774                }
775                return out;
776            }
777            _ => {}
778        }
779
780        let typ = self.typ.apply(s);
781        let kind = match self.kind.as_ref() {
782            TypedExprKind::Bool(v) => TypedExprKind::Bool(*v),
783            TypedExprKind::Uint(v) => TypedExprKind::Uint(*v),
784            TypedExprKind::Int(v) => TypedExprKind::Int(*v),
785            TypedExprKind::Float(v) => TypedExprKind::Float(*v),
786            TypedExprKind::String(v) => TypedExprKind::String(v.clone()),
787            TypedExprKind::Uuid(v) => TypedExprKind::Uuid(*v),
788            TypedExprKind::DateTime(v) => TypedExprKind::DateTime(*v),
789            TypedExprKind::Hole => TypedExprKind::Hole,
790            TypedExprKind::Tuple(elems) => {
791                TypedExprKind::Tuple(elems.iter().map(|e| Arc::new(e.apply(s))).collect())
792            }
793            TypedExprKind::List(elems) => {
794                TypedExprKind::List(elems.iter().map(|e| Arc::new(e.apply(s))).collect())
795            }
796            TypedExprKind::Dict(kvs) => {
797                let mut out = BTreeMap::new();
798                for (k, v) in kvs {
799                    out.insert(k.clone(), Arc::new(v.apply(s)));
800                }
801                TypedExprKind::Dict(out)
802            }
803            TypedExprKind::RecordUpdate { base, updates } => {
804                let mut out = BTreeMap::new();
805                for (k, v) in updates {
806                    out.insert(k.clone(), Arc::new(v.apply(s)));
807                }
808                TypedExprKind::RecordUpdate {
809                    base: Arc::new(base.apply(s)),
810                    updates: out,
811                }
812            }
813            TypedExprKind::Var { name, overloads } => TypedExprKind::Var {
814                name: name.clone(),
815                overloads: overloads.iter().map(|t| t.apply(s)).collect(),
816            },
817            TypedExprKind::App(f, x) => {
818                TypedExprKind::App(Arc::new(f.apply(s)), Arc::new(x.apply(s)))
819            }
820            TypedExprKind::Project { expr, field } => TypedExprKind::Project {
821                expr: Arc::new(expr.apply(s)),
822                field: field.clone(),
823            },
824            TypedExprKind::Lam { param, body } => TypedExprKind::Lam {
825                param: param.clone(),
826                body: Arc::new(body.apply(s)),
827            },
828            TypedExprKind::Let { name, def, body } => TypedExprKind::Let {
829                name: name.clone(),
830                def: Arc::new(def.apply(s)),
831                body: Arc::new(body.apply(s)),
832            },
833            TypedExprKind::LetRec { bindings, body } => TypedExprKind::LetRec {
834                bindings: bindings
835                    .iter()
836                    .map(|(name, def)| (name.clone(), Arc::new(def.apply(s))))
837                    .collect(),
838                body: Arc::new(body.apply(s)),
839            },
840            TypedExprKind::Ite {
841                cond,
842                then_expr,
843                else_expr,
844            } => TypedExprKind::Ite {
845                cond: Arc::new(cond.apply(s)),
846                then_expr: Arc::new(then_expr.apply(s)),
847                else_expr: Arc::new(else_expr.apply(s)),
848            },
849            TypedExprKind::Match { scrutinee, arms } => TypedExprKind::Match {
850                scrutinee: Arc::new(scrutinee.apply(s)),
851                arms: arms
852                    .iter()
853                    .map(|(p, e)| (p.clone(), Arc::new(e.apply(s))))
854                    .collect(),
855            },
856        };
857        TypedExpr::new(typ, kind)
858    }
859}
860
861#[derive(Clone, Debug, PartialEq)]
862pub enum TypedExprKind {
863    Bool(bool),
864    Uint(u64),
865    Int(i64),
866    Float(f64),
867    String(String),
868    Uuid(Uuid),
869    DateTime(DateTime<Utc>),
870    Hole,
871    Tuple(Vec<Arc<TypedExpr>>),
872    List(Vec<Arc<TypedExpr>>),
873    Dict(BTreeMap<Symbol, Arc<TypedExpr>>),
874    RecordUpdate {
875        base: Arc<TypedExpr>,
876        updates: BTreeMap<Symbol, Arc<TypedExpr>>,
877    },
878    Var {
879        name: Symbol,
880        overloads: Vec<Type>,
881    },
882    App(Arc<TypedExpr>, Arc<TypedExpr>),
883    Project {
884        expr: Arc<TypedExpr>,
885        field: Symbol,
886    },
887    Lam {
888        param: Symbol,
889        body: Arc<TypedExpr>,
890    },
891    Let {
892        name: Symbol,
893        def: Arc<TypedExpr>,
894        body: Arc<TypedExpr>,
895    },
896    LetRec {
897        bindings: Vec<(Symbol, Arc<TypedExpr>)>,
898        body: Arc<TypedExpr>,
899    },
900    Ite {
901        cond: Arc<TypedExpr>,
902        then_expr: Arc<TypedExpr>,
903        else_expr: Arc<TypedExpr>,
904    },
905    Match {
906        scrutinee: Arc<TypedExpr>,
907        arms: Vec<(Pattern, Arc<TypedExpr>)>,
908    },
909}
910
911#[derive(Default, Debug, Clone)]
912pub struct TypeEnv {
913    pub values: HashTrieMapSync<Symbol, Vec<Scheme>>,
914}
915
916impl TypeEnv {
917    pub fn new() -> Self {
918        Self {
919            values: HashTrieMapSync::new_sync(),
920        }
921    }
922
923    pub fn extend(&mut self, name: Symbol, scheme: Scheme) {
924        self.values = self.values.insert(name, vec![scheme]);
925    }
926
927    pub fn extend_overload(&mut self, name: Symbol, scheme: Scheme) {
928        let mut schemes = self.values.get(&name).cloned().unwrap_or_default();
929        schemes.push(scheme);
930        self.values = self.values.insert(name, schemes);
931    }
932
933    pub fn remove(&mut self, name: &Symbol) {
934        self.values = self.values.remove(name);
935    }
936
937    pub fn lookup(&self, name: &Symbol) -> Option<&[Scheme]> {
938        self.values.get(name).map(|schemes| schemes.as_slice())
939    }
940}
941
942impl Types for TypeEnv {
943    fn apply(&self, s: &Subst) -> Self {
944        let mut values = HashTrieMapSync::new_sync();
945        for (k, v) in self.values.iter() {
946            let updated = v
947                .iter()
948                .map(|scheme| {
949                    // Most schemes in environments are monomorphic. Don't walk
950                    // and rebuild trees unless we actually have work to do.
951                    if scheme.vars.is_empty() && !subst_is_empty(s) {
952                        scheme.apply(s)
953                    } else {
954                        scheme.clone()
955                    }
956                })
957                .collect();
958            values = values.insert(k.clone(), updated);
959        }
960        TypeEnv { values }
961    }
962
963    fn ftv(&self) -> BTreeSet<TypeVarId> {
964        self.values
965            .iter()
966            .flat_map(|(_, schemes)| schemes.iter().flat_map(Types::ftv))
967            .collect()
968    }
969}
970
971/// A named type parameter for an ADT (e.g. `a` in `List a`).
972#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
973pub struct AdtParam {
974    pub name: Symbol,
975    pub var: TypeVar,
976}
977
978/// A single ADT variant with zero or more constructor arguments.
979#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
980pub struct AdtVariant {
981    pub name: Symbol,
982    pub args: Vec<Type>,
983}
984
985/// A type declaration for an algebraic data type.
986///
987/// This only describes the *type* surface (params + variants). It does not
988/// introduce any runtime values by itself. Runtime values are created by
989/// injecting constructor schemes into the environment (see `inject_adt`).
990#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
991pub struct AdtDecl {
992    pub name: Symbol,
993    pub params: Vec<AdtParam>,
994    pub variants: Vec<AdtVariant>,
995}
996
997impl AdtDecl {
998    pub fn new(name: &Symbol, param_names: &[Symbol], supply: &mut TypeVarSupply) -> Self {
999        let params = param_names
1000            .iter()
1001            .map(|p| AdtParam {
1002                name: p.clone(),
1003                var: supply.fresh(Some(p.clone())),
1004            })
1005            .collect();
1006        Self {
1007            name: name.clone(),
1008            params,
1009            variants: Vec::new(),
1010        }
1011    }
1012
1013    pub fn param_type(&self, name: &Symbol) -> Option<Type> {
1014        self.params
1015            .iter()
1016            .find(|p| &p.name == name)
1017            .map(|p| Type::var(p.var.clone()))
1018    }
1019
1020    pub fn add_variant(&mut self, name: Symbol, args: Vec<Type>) {
1021        self.variants.push(AdtVariant { name, args });
1022    }
1023
1024    pub fn result_type(&self) -> Type {
1025        let mut ty = Type::con(&self.name, self.params.len());
1026        for param in &self.params {
1027            ty = Type::app(ty, Type::var(param.var.clone()));
1028        }
1029        ty
1030    }
1031
1032    /// Build constructor schemes of the form:
1033    /// `C :: a1 -> a2 -> ... -> T params`.
1034    pub fn constructor_schemes(&self) -> Vec<(Symbol, Scheme)> {
1035        let result_ty = self.result_type();
1036        let vars: Vec<TypeVar> = self.params.iter().map(|p| p.var.clone()).collect();
1037        let mut out = Vec::new();
1038        for variant in &self.variants {
1039            let mut typ = result_ty.clone();
1040            for arg in variant.args.iter().rev() {
1041                typ = Type::fun(arg.clone(), typ);
1042            }
1043            out.push((variant.name.clone(), Scheme::new(vars.clone(), vec![], typ)));
1044        }
1045        out
1046    }
1047}
1048
1049/// Rust-side type metadata for values that can appear at a Rex boundary.
1050///
1051/// Implement this trait for any Rust type that appears in a typed host function
1052/// signature, a derived Rex ADT field, or other embedder-facing conversion
1053/// point. The returned [`Type`] is the Rex type that users see in signatures and
1054/// type errors.
1055///
1056/// Primitive Rust types such as integers, floats, `bool`, `String`, `Vec<T>`,
1057/// `Option<T>`, and `Result<T, E>` already implement `RexType`. For Rust structs
1058/// and enums that should be visible as Rex algebraic data types, prefer
1059/// `#[derive(rex::Rex)]`, which implements both `RexType` and [`RexAdt`].
1060pub trait RexType {
1061    /// Return the Rex type corresponding to `Self`.
1062    ///
1063    /// This type is used when Rex builds host function signatures, checks calls
1064    /// to native functions, and discovers the declarations needed for ADT
1065    /// registration.
1066    fn rex_type() -> Type;
1067
1068    /// Append Rex ADT declarations required by this type to `out`.
1069    ///
1070    /// The default implementation is intentionally empty, which is correct for
1071    /// primitive and leaf types that do not introduce Rex ADT declarations.
1072    /// Derived ADTs override this to collect declarations for the full acyclic
1073    /// family reachable from `Self` and then append their own declaration.
1074    ///
1075    /// Callers that register the family are responsible for ordering and
1076    /// validating the declarations before injection.
1077    fn collect_rex_family(_out: &mut Vec<AdtDecl>) -> Result<(), TypeError> {
1078        Ok(())
1079    }
1080}
1081
1082/// Rust-side declaration metadata for a type represented as a Rex ADT.
1083///
1084/// `RexAdt` extends [`RexType`] for Rust structs and enums that have a named Rex
1085/// algebraic data type declaration. The engine and module APIs use this trait to
1086/// register constructors and type declarations before Rex code constructs or
1087/// consumes values of the Rust type.
1088///
1089/// Most embedders should derive this with `#[derive(rex::Rex)]`. Manual
1090/// implementations are useful for hand-written bridges or types whose Rex shape
1091/// differs from their Rust fields.
1092pub trait RexAdt: RexType {
1093    /// Return the single Rex ADT declaration for `Self`.
1094    ///
1095    /// This should describe only the type represented by `Self`; dependencies
1096    /// belong in [`RexType::collect_rex_family`].
1097    fn rex_adt_decl() -> Result<AdtDecl, TypeError>;
1098
1099    /// Return the ADT family needed to register `Self`.
1100    ///
1101    /// The default implementation delegates to [`RexType::collect_rex_family`].
1102    /// Derived implementations of `collect_rex_family` include dependencies and
1103    /// `Self`; manual implementations can override this method when they need a
1104    /// custom family collection strategy.
1105    fn rex_adt_family() -> Result<Vec<AdtDecl>, TypeError> {
1106        let mut out = Vec::new();
1107        <Self as RexType>::collect_rex_family(&mut out)?;
1108        Ok(out)
1109    }
1110}
1111
1112impl RexType for bool {
1113    fn rex_type() -> Type {
1114        Type::builtin(BuiltinTypeId::Bool)
1115    }
1116}
1117
1118impl RexType for u8 {
1119    fn rex_type() -> Type {
1120        Type::builtin(BuiltinTypeId::U8)
1121    }
1122}
1123
1124impl RexType for u16 {
1125    fn rex_type() -> Type {
1126        Type::builtin(BuiltinTypeId::U16)
1127    }
1128}
1129
1130impl RexType for u32 {
1131    fn rex_type() -> Type {
1132        Type::builtin(BuiltinTypeId::U32)
1133    }
1134}
1135
1136impl RexType for u64 {
1137    fn rex_type() -> Type {
1138        Type::builtin(BuiltinTypeId::U64)
1139    }
1140}
1141
1142impl RexType for i8 {
1143    fn rex_type() -> Type {
1144        Type::builtin(BuiltinTypeId::I8)
1145    }
1146}
1147
1148impl RexType for i16 {
1149    fn rex_type() -> Type {
1150        Type::builtin(BuiltinTypeId::I16)
1151    }
1152}
1153
1154impl RexType for i32 {
1155    fn rex_type() -> Type {
1156        Type::builtin(BuiltinTypeId::I32)
1157    }
1158}
1159
1160impl RexType for i64 {
1161    fn rex_type() -> Type {
1162        Type::builtin(BuiltinTypeId::I64)
1163    }
1164}
1165
1166impl RexType for f32 {
1167    fn rex_type() -> Type {
1168        Type::builtin(BuiltinTypeId::F32)
1169    }
1170}
1171
1172impl RexType for f64 {
1173    fn rex_type() -> Type {
1174        Type::builtin(BuiltinTypeId::F64)
1175    }
1176}
1177
1178impl RexType for String {
1179    fn rex_type() -> Type {
1180        Type::builtin(BuiltinTypeId::String)
1181    }
1182}
1183
1184impl RexType for &str {
1185    fn rex_type() -> Type {
1186        Type::builtin(BuiltinTypeId::String)
1187    }
1188}
1189
1190impl RexType for Uuid {
1191    fn rex_type() -> Type {
1192        Type::builtin(BuiltinTypeId::Uuid)
1193    }
1194}
1195
1196impl RexType for DateTime<Utc> {
1197    fn rex_type() -> Type {
1198        Type::builtin(BuiltinTypeId::DateTime)
1199    }
1200}
1201
1202impl<T: RexType> RexType for Vec<T> {
1203    fn rex_type() -> Type {
1204        Type::app(Type::builtin(BuiltinTypeId::Array), T::rex_type())
1205    }
1206}
1207
1208impl<T: RexType> RexType for Option<T> {
1209    fn rex_type() -> Type {
1210        Type::app(Type::builtin(BuiltinTypeId::Option), T::rex_type())
1211    }
1212}
1213
1214impl<T: RexType, E: RexType> RexType for Result<T, E> {
1215    fn rex_type() -> Type {
1216        Type::app(
1217            Type::app(Type::builtin(BuiltinTypeId::Result), E::rex_type()),
1218            T::rex_type(),
1219        )
1220    }
1221}
1222
1223impl RexType for () {
1224    fn rex_type() -> Type {
1225        Type::tuple(vec![])
1226    }
1227}
1228
1229macro_rules! impl_tuple_rex_type {
1230    ($($name:ident),+) => {
1231        impl<$($name: RexType),+> RexType for ($($name,)+) {
1232            fn rex_type() -> Type {
1233                Type::tuple(vec![$($name::rex_type()),+])
1234            }
1235        }
1236    };
1237}
1238
1239impl_tuple_rex_type!(A0);
1240impl_tuple_rex_type!(A0, A1);
1241impl_tuple_rex_type!(A0, A1, A2);
1242impl_tuple_rex_type!(A0, A1, A2, A3);
1243impl_tuple_rex_type!(A0, A1, A2, A3, A4);
1244impl_tuple_rex_type!(A0, A1, A2, A3, A4, A5);
1245impl_tuple_rex_type!(A0, A1, A2, A3, A4, A5, A6);
1246impl_tuple_rex_type!(A0, A1, A2, A3, A4, A5, A6, A7);
1247
1248#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
1249pub struct Class {
1250    pub supers: Vec<Symbol>,
1251}
1252
1253impl Class {
1254    pub fn new(supers: Vec<Symbol>) -> Self {
1255        Self { supers }
1256    }
1257}
1258
1259#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
1260pub struct Instance {
1261    pub context: Vec<Predicate>,
1262    pub head: Predicate,
1263}
1264
1265impl Instance {
1266    pub fn new(context: Vec<Predicate>, head: Predicate) -> Self {
1267        Self { context, head }
1268    }
1269}
1270
1271#[derive(Default, Debug, Clone)]
1272pub struct ClassEnv {
1273    pub classes: BTreeMap<Symbol, Class>,
1274    pub instances: BTreeMap<Symbol, Vec<Instance>>,
1275}
1276
1277impl ClassEnv {
1278    pub fn new() -> Self {
1279        Self {
1280            classes: BTreeMap::new(),
1281            instances: BTreeMap::new(),
1282        }
1283    }
1284
1285    pub fn add_class(&mut self, name: Symbol, supers: Vec<Symbol>) {
1286        self.classes.insert(name, Class::new(supers));
1287    }
1288
1289    pub fn add_instance(&mut self, class: Symbol, inst: Instance) {
1290        self.instances.entry(class).or_default().push(inst);
1291    }
1292
1293    pub fn supers_of(&self, class: &Symbol) -> Vec<Symbol> {
1294        self.classes
1295            .get(class)
1296            .map(|c| c.supers.clone())
1297            .unwrap_or_default()
1298    }
1299}
1300
1301/// Collect all user-defined ADT constructors referenced by the provided types.
1302///
1303/// This walks each type recursively (including nested occurrences), returns a
1304/// deduplicated list of constructor heads, and rejects ambiguous constructor
1305/// names that appear with incompatible definitions.
1306///
1307/// The returned `Type`s are constructor heads (for example `Foo`), suitable
1308/// for passing to embedder utilities that derive `AdtDecl`s from type
1309/// constructors.
1310///
1311/// # Examples
1312///
1313/// ```rust,ignore
1314/// use rex_ts::{collect_adts_in_types, BuiltinTypeId, Type};
1315///
1316/// let types = vec![
1317///     Type::app(Type::user_con("Foo", 1), Type::builtin(BuiltinTypeId::I32)),
1318///     Type::fun(Type::user_con("Bar", 0), Type::user_con("Foo", 1)),
1319/// ];
1320///
1321/// let adts = collect_adts_in_types(types).unwrap();
1322/// assert_eq!(adts, vec![Type::user_con("Foo", 1), Type::user_con("Bar", 0)]);
1323/// ```
1324///
1325/// ```rust,ignore
1326/// use rex_ts::{collect_adts_in_types, Type};
1327///
1328/// let err = collect_adts_in_types(vec![
1329///     Type::user_con("Thing", 1),
1330///     Type::user_con("Thing", 2),
1331/// ])
1332/// .unwrap_err();
1333///
1334/// assert_eq!(err.conflicts.len(), 1);
1335/// assert_eq!(err.conflicts[0].name.as_ref(), "Thing");
1336/// ```
1337pub fn collect_adts_in_types(types: Vec<Type>) -> Result<Vec<Type>, CollectAdtsError> {
1338    let mut out = Vec::new();
1339    let mut seen = BTreeSet::new();
1340    let mut defs_by_name: BTreeMap<Symbol, Vec<Type>> = BTreeMap::new();
1341    for typ in &types {
1342        typ.for_each(|t| {
1343            if let TypeKind::Con(tc) = t.as_ref() {
1344                // Builtins are not embeddable ADT declarations.
1345                if let Some(name) = tc.user_name() {
1346                    let adt = Type::new(TypeKind::Con(tc.clone()));
1347                    if seen.insert(adt.clone()) {
1348                        out.push(adt.clone());
1349                    }
1350                    let defs = defs_by_name.entry(name.clone()).or_default();
1351                    if !defs.contains(&adt) {
1352                        defs.push(adt);
1353                    }
1354                }
1355            }
1356        });
1357    }
1358
1359    let conflicts: Vec<AdtConflict> = defs_by_name
1360        .into_iter()
1361        .filter_map(|(name, definitions)| {
1362            (definitions.len() > 1).then_some(AdtConflict { name, definitions })
1363        })
1364        .collect();
1365    if !conflicts.is_empty() {
1366        return Err(CollectAdtsError { conflicts });
1367    }
1368
1369    Ok(out)
1370}
1371
1372fn collect_adts_error_to_type(err: CollectAdtsError) -> TypeError {
1373    let details = err
1374        .conflicts
1375        .into_iter()
1376        .map(|conflict| {
1377            let defs = conflict
1378                .definitions
1379                .iter()
1380                .map(ToString::to_string)
1381                .collect::<Vec<_>>()
1382                .join(", ");
1383            format!("{}: [{defs}]", conflict.name)
1384        })
1385        .collect::<Vec<_>>()
1386        .join("; ");
1387    TypeError::Internal(format!(
1388        "conflicting ADT definitions discovered in input types: {details}"
1389    ))
1390}
1391
1392fn type_head_and_args_for_adt_family(typ: &Type) -> Result<(Symbol, usize, Vec<Type>), TypeError> {
1393    let mut args = Vec::new();
1394    let mut head = typ;
1395    while let TypeKind::App(f, arg) = head.as_ref() {
1396        args.push(arg.clone());
1397        head = f;
1398    }
1399    args.reverse();
1400
1401    let TypeKind::Con(con) = head.as_ref() else {
1402        return Err(TypeError::Internal(format!(
1403            "cannot build ADT declaration from non-constructor type `{typ}`"
1404        )));
1405    };
1406    if !args.is_empty() && args.len() != con.arity() {
1407        return Err(TypeError::Internal(format!(
1408            "constructor `{}` expected {} type arguments but got {} in `{typ}`",
1409            con.name_str(),
1410            con.arity(),
1411            args.len()
1412        )));
1413    }
1414    Ok((con.name(), con.arity(), args))
1415}
1416
1417fn type_head_for_adt_family(typ: &Type) -> Result<Type, TypeError> {
1418    let (name, arity, _args) = type_head_and_args_for_adt_family(typ)?;
1419    Ok(Type::con(name.as_ref(), arity))
1420}
1421
1422fn adt_shape(adt: &AdtDecl) -> String {
1423    let param_names: BTreeMap<_, _> = adt
1424        .params
1425        .iter()
1426        .enumerate()
1427        .map(|(idx, param)| (param.var.id, format!("t{idx}")))
1428        .collect();
1429    let mut variants = adt
1430        .variants
1431        .iter()
1432        .map(|variant| {
1433            let args = variant
1434                .args
1435                .iter()
1436                .map(|arg| normalize_type_for_shape(arg, &param_names))
1437                .collect::<Vec<_>>()
1438                .join(", ");
1439            format!("{}({args})", variant.name)
1440        })
1441        .collect::<Vec<_>>();
1442    variants.sort();
1443    format!("{}[{}]", adt.name, variants.join(" | "))
1444}
1445
1446fn normalize_type_for_shape(typ: &Type, param_names: &BTreeMap<usize, String>) -> String {
1447    match typ.as_ref() {
1448        TypeKind::Var(tv) => param_names
1449            .get(&tv.id)
1450            .cloned()
1451            .unwrap_or_else(|| format!("v{}", tv.id)),
1452        TypeKind::Con(con) => con.name_str().to_string(),
1453        TypeKind::App(fun, arg) => format!(
1454            "({} {})",
1455            normalize_type_for_shape(fun, param_names),
1456            normalize_type_for_shape(arg, param_names)
1457        ),
1458        TypeKind::Fun(arg, ret) => format!(
1459            "({} -> {})",
1460            normalize_type_for_shape(arg, param_names),
1461            normalize_type_for_shape(ret, param_names)
1462        ),
1463        TypeKind::Tuple(elems) => format!(
1464            "({})",
1465            elems
1466                .iter()
1467                .map(|elem| normalize_type_for_shape(elem, param_names))
1468                .collect::<Vec<_>>()
1469                .join(", ")
1470        ),
1471        TypeKind::Record(fields) => format!(
1472            "{{{}}}",
1473            fields
1474                .iter()
1475                .map(|(name, typ)| format!(
1476                    "{name}: {}",
1477                    normalize_type_for_shape(typ, param_names)
1478                ))
1479                .collect::<Vec<_>>()
1480                .join(", ")
1481        ),
1482    }
1483}
1484
1485fn adt_shape_eq(left: &AdtDecl, right: &AdtDecl) -> bool {
1486    adt_shape(left) == adt_shape(right)
1487}
1488
1489fn adt_direct_dependencies(adt: &AdtDecl) -> Result<Vec<Type>, TypeError> {
1490    let types = adt
1491        .variants
1492        .iter()
1493        .flat_map(|variant| variant.args.iter().cloned())
1494        .collect::<Vec<_>>();
1495    let deps = collect_adts_in_types(types).map_err(collect_adts_error_to_type)?;
1496    deps.into_iter()
1497        .map(|typ| type_head_for_adt_family(&typ))
1498        .collect()
1499}
1500
1501/// Order a family of algebraic data type declarations for registration.
1502///
1503/// An ADT family is the root ADT an embedder wants to expose plus the
1504/// user-defined ADTs that appear in its variant fields, recursively. For
1505/// example, if Rust type `Workflow` contains a `Step` field and `Step` contains
1506/// a `Resource` field, then `Workflow`, `Step`, and `Resource` form the family
1507/// that must be registered together.
1508///
1509/// This function deduplicates identical declarations, rejects conflicting
1510/// declarations for the same ADT name, rejects dependency cycles, and returns
1511/// the declarations in dependency order so nested ADTs are registered before
1512/// the ADTs that refer to them.
1513pub fn order_adt_family(adts: Vec<AdtDecl>) -> Result<Vec<AdtDecl>, TypeError> {
1514    let mut unique = BTreeMap::new();
1515    for adt in adts {
1516        match unique.get(&adt.name) {
1517            Some(existing) if adt_shape_eq(existing, &adt) => {}
1518            Some(existing) => {
1519                return Err(TypeError::Internal(format!(
1520                    "conflicting ADT family definitions for `{}`: {} vs {}",
1521                    adt.name,
1522                    adt_shape(existing),
1523                    adt_shape(&adt)
1524                )));
1525            }
1526            None => {
1527                unique.insert(adt.name.clone(), adt);
1528            }
1529        }
1530    }
1531
1532    let mut visiting = Vec::<Symbol>::new();
1533    let mut visited = BTreeSet::<Symbol>::new();
1534    let mut ordered = Vec::<AdtDecl>::new();
1535
1536    fn visit(
1537        name: &Symbol,
1538        unique: &BTreeMap<Symbol, AdtDecl>,
1539        visiting: &mut Vec<Symbol>,
1540        visited: &mut BTreeSet<Symbol>,
1541        ordered: &mut Vec<AdtDecl>,
1542    ) -> Result<(), TypeError> {
1543        if visited.contains(name) {
1544            return Ok(());
1545        }
1546        if let Some(idx) = visiting.iter().position(|current| current == name) {
1547            let mut cycle = visiting[idx..]
1548                .iter()
1549                .map(ToString::to_string)
1550                .collect::<Vec<_>>();
1551            cycle.push(name.to_string());
1552            return Err(TypeError::Internal(format!(
1553                "cyclic ADT auto-registration is not supported yet: {}",
1554                cycle.join(" -> ")
1555            )));
1556        }
1557
1558        let adt = unique
1559            .get(name)
1560            .ok_or_else(|| TypeError::Internal(format!("missing ADT `{name}` during ordering")))?;
1561        visiting.push(name.clone());
1562        for dep in adt_direct_dependencies(adt)? {
1563            let dep_head = type_head_for_adt_family(&dep)?;
1564            let TypeKind::Con(dep_con) = dep_head.as_ref() else {
1565                return Err(TypeError::Internal(format!(
1566                    "dependency head for `{name}` was not a constructor"
1567                )));
1568            };
1569            if let Some(name) = dep_con.user_name()
1570                && unique.contains_key(name)
1571            {
1572                visit(name, unique, visiting, visited, ordered)?;
1573            }
1574        }
1575        visiting.pop();
1576        visited.insert(name.clone());
1577        ordered.push(adt.clone());
1578        Ok(())
1579    }
1580
1581    let mut names = unique.keys().cloned().collect::<Vec<_>>();
1582    names.sort();
1583    for name in names {
1584        visit(&name, &unique, &mut visiting, &mut visited, &mut ordered)?;
1585    }
1586    Ok(ordered)
1587}