Skip to main content

rex_typesystem/
typesystem.rs

1//! Core type system implementation for Rex.
2
3use std::collections::{BTreeMap, BTreeSet};
4use std::sync::Arc;
5
6use rex_ast::expr::{
7    ClassDecl, ClassMethodSig, Decl, DeclareFnDecl, Expr, FnDecl, InstanceDecl, InstanceMethodImpl,
8    Scope, Symbol, TypeConstraint, TypeDecl, TypeExpr, sym,
9};
10use rex_lexer::span::Span;
11
12use crate::prelude;
13
14pub use crate::{
15    // inference::{
16    //     infer,
17    //     infer_typed,
18    //     infer_with_gas,
19    //     infer_typed_with_gas,
20    // },
21    unification::{Subst, compose_subst, unify},
22};
23
24use crate::{
25    error::TypeError,
26    inference::infer_typed,
27    types::{
28        AdtDecl, BuiltinTypeId, ClassEnv, Instance, Predicate, Scheme, Type, TypeEnv, TypeKind,
29        TypeVar, TypeVarId, TypedExpr, Types,
30    },
31    unification::scheme_compatible,
32};
33
34fn format_constraints_referencing_vars(preds: &[Predicate], vars: &[TypeVarId]) -> String {
35    if vars.is_empty() {
36        return String::new();
37    }
38    let var_set: BTreeSet<TypeVarId> = vars.iter().copied().collect();
39    let mut parts = Vec::new();
40    for pred in preds {
41        let ftv = pred.ftv();
42        if ftv.iter().any(|v| var_set.contains(v)) {
43            parts.push(format!("{} {}", pred.class, pred.typ));
44        }
45    }
46    if parts.is_empty() {
47        // Fallback: show all constraints if the filtering logic misses something.
48        for pred in preds {
49            parts.push(format!("{} {}", pred.class, pred.typ));
50        }
51    }
52    parts.join(", ")
53}
54
55pub(crate) fn reject_ambiguous_scheme(scheme: &Scheme) -> Result<(), TypeError> {
56    // Only reject *quantified* ambiguous variables. Variables free in the
57    // environment are allowed to appear only in predicates, since they can be
58    // determined by outer context.
59    let quantified: BTreeSet<TypeVarId> = scheme.vars.iter().map(|v| v.id).collect();
60    if quantified.is_empty() {
61        return Ok(());
62    }
63
64    let typ_ftv = scheme.typ.ftv();
65    let mut vars = BTreeSet::new();
66    for pred in &scheme.preds {
67        let TypeKind::Var(tv) = pred.typ.as_ref() else {
68            continue;
69        };
70        if quantified.contains(&tv.id) && !typ_ftv.contains(&tv.id) {
71            vars.insert(tv.id);
72        }
73    }
74
75    if vars.is_empty() {
76        return Ok(());
77    }
78    let mut vars: Vec<TypeVarId> = vars.into_iter().collect();
79    vars.sort_unstable();
80    let constraints = format_constraints_referencing_vars(&scheme.preds, &vars);
81    Err(TypeError::AmbiguousTypeVars { vars, constraints })
82}
83
/// Resource limits for the type system.
#[derive(Clone, Copy, Debug)]
pub struct TypeSystemLimits {
    /// Cap on inference depth (presumably the inferencer's recursion depth —
    /// confirm in the `inference` module); `None` means no cap.
    pub max_infer_depth: Option<usize>,
}

impl TypeSystemLimits {
    /// Depth cap used by [`TypeSystemLimits::safe_defaults`].
    const SAFE_MAX_INFER_DEPTH: usize = 4096;

    /// Limits with every cap disabled.
    pub fn unlimited() -> Self {
        Self {
            max_infer_depth: None,
        }
    }

    /// Conservative limits: inference depth capped at 4096.
    pub fn safe_defaults() -> Self {
        Self {
            max_infer_depth: Some(Self::SAFE_MAX_INFER_DEPTH),
        }
    }
}

impl Default for TypeSystemLimits {
    /// Same as [`TypeSystemLimits::safe_defaults`].
    fn default() -> Self {
        Self::safe_defaults()
    }
}
108
109fn superclass_closure(class_env: &ClassEnv, given: &[Predicate]) -> Vec<Predicate> {
110    let mut closure: Vec<Predicate> = given.to_vec();
111    let mut i = 0;
112    while i < closure.len() {
113        let p = closure[i].clone();
114        for sup in class_env.supers_of(&p.class) {
115            closure.push(Predicate::new(sup, p.typ.clone()));
116        }
117        i += 1;
118    }
119    closure
120}
121
122fn check_non_ground_predicates_declared(
123    class_env: &ClassEnv,
124    declared: &[Predicate],
125    inferred: &[Predicate],
126) -> Result<(), TypeError> {
127    // Compare by a stable, user-facing rendering (`Default a`, `Foldable t`, ...),
128    // rather than `TypeVarId`, so signature variables that only appear in
129    // predicates (and thus aren't related by unification) still match up.
130    let closure = superclass_closure(class_env, declared);
131    let closure_keys: BTreeSet<String> = closure
132        .iter()
133        .map(|p| format!("{} {}", p.class, p.typ))
134        .collect();
135    let mut missing = Vec::new();
136    for pred in inferred {
137        if pred.typ.ftv().is_empty() {
138            continue;
139        }
140        let key = format!("{} {}", pred.class, pred.typ);
141        if !closure_keys.contains(&key) {
142            missing.push(key);
143        }
144    }
145
146    missing.sort();
147    missing.dedup();
148    if missing.is_empty() {
149        return Ok(());
150    }
151    Err(TypeError::MissingConstraints {
152        constraints: missing.join(", "),
153    })
154}
155
156fn type_term_remaining_arity(ty: &Type) -> Option<usize> {
157    match ty.as_ref() {
158        TypeKind::Var(_) => None,
159        TypeKind::Con(tc) => Some(tc.arity),
160        TypeKind::App(l, _) => {
161            let a = type_term_remaining_arity(l)?;
162            Some(a.saturating_sub(1))
163        }
164        TypeKind::Fun(..) | TypeKind::Tuple(..) | TypeKind::Record(..) => Some(0),
165    }
166}
167
168fn max_head_app_arity_for_var(ty: &Type, var_id: TypeVarId) -> usize {
169    let mut max_arity = 0usize;
170    let mut stack: Vec<&Type> = vec![ty];
171    while let Some(t) = stack.pop() {
172        match t.as_ref() {
173            TypeKind::Var(_) | TypeKind::Con(_) => {}
174            TypeKind::App(l, r) => {
175                // Record the full application depth at this node.
176                let mut head = t;
177                let mut args = 0usize;
178                while let TypeKind::App(left, _) = head.as_ref() {
179                    args += 1;
180                    head = left;
181                }
182                if let TypeKind::Var(tv) = head.as_ref()
183                    && tv.id == var_id
184                {
185                    max_arity = max_arity.max(args);
186                }
187                stack.push(l);
188                stack.push(r);
189            }
190            TypeKind::Fun(a, b) => {
191                stack.push(a);
192                stack.push(b);
193            }
194            TypeKind::Tuple(ts) => {
195                for t in ts {
196                    stack.push(t);
197                }
198            }
199            TypeKind::Record(fields) => {
200                for (_, t) in fields {
201                    stack.push(t);
202                }
203            }
204        }
205    }
206    max_arity
207}
208
209#[derive(Default, Debug, Clone)]
210pub struct TypeVarSupply {
211    counter: TypeVarId,
212}
213
214impl TypeVarSupply {
215    pub fn new() -> Self {
216        Self { counter: 0 }
217    }
218
219    pub fn fresh(&mut self, name_hint: impl Into<Option<Symbol>>) -> TypeVar {
220        let tv = TypeVar::new(self.counter, name_hint.into());
221        self.counter += 1;
222        tv
223    }
224}
225
226pub(crate) fn is_integral_literal_expr(expr: &Expr) -> bool {
227    matches!(expr, Expr::Int(..) | Expr::Uint(..))
228}
229
230/// Turn a monotype `typ` (plus constraints `preds`) into a polymorphic `Scheme`
231/// by quantifying over the type variables not free in `env`.
232pub fn generalize(env: &TypeEnv, preds: Vec<Predicate>, typ: Type) -> Scheme {
233    let mut vars: Vec<TypeVar> = typ
234        .ftv()
235        .union(&preds.ftv())
236        .copied()
237        .collect::<BTreeSet<_>>()
238        .difference(&env.ftv())
239        .cloned()
240        .map(|id| TypeVar::new(id, None))
241        .collect();
242    vars.sort_by_key(|v| v.id);
243    Scheme::new(vars, preds, typ)
244}
245
246pub fn instantiate(scheme: &Scheme, supply: &mut TypeVarSupply) -> (Vec<Predicate>, Type) {
247    // Instantiate replaces all quantified variables with fresh unification
248    // variables, preserving the original name as a debugging hint.
249    let mut subst = Subst::new_sync();
250    for v in &scheme.vars {
251        subst = subst.insert(v.id, Type::var(supply.fresh(v.name.clone())));
252    }
253    (scheme.preds.apply(&subst), scheme.typ.apply(&subst))
254}
255
256pub fn entails(
257    class_env: &ClassEnv,
258    given: &[Predicate],
259    pred: &Predicate,
260) -> Result<bool, TypeError> {
261    // Expand given with superclasses.
262    let mut closure: Vec<Predicate> = given.to_vec();
263    let mut i = 0;
264    while i < closure.len() {
265        let p = closure[i].clone();
266        for sup in class_env.supers_of(&p.class) {
267            closure.push(Predicate::new(sup, p.typ.clone()));
268        }
269        i += 1;
270    }
271
272    if closure
273        .iter()
274        .any(|p| p.class == pred.class && p.typ == pred.typ)
275    {
276        return Ok(true);
277    }
278
279    if !class_env.classes.contains_key(&pred.class) {
280        return Err(TypeError::UnknownClass(pred.class.clone()));
281    }
282
283    if let Some(instances) = class_env.instances.get(&pred.class) {
284        for inst in instances {
285            if let Ok(s) = unify(&inst.head.typ, &pred.typ) {
286                let ctx = inst.context.apply(&s);
287                if ctx
288                    .iter()
289                    .all(|c| entails(class_env, &closure, c).unwrap_or(false))
290                {
291                    return Ok(true);
292                }
293            }
294        }
295    }
296    Ok(false)
297}
298
299#[derive(Default, Debug, Clone)]
300pub struct TypeSystem {
301    pub env: TypeEnv,
302    pub classes: ClassEnv,
303    pub adts: BTreeMap<Symbol, AdtDecl>,
304    pub class_info: BTreeMap<Symbol, ClassInfo>,
305    pub class_methods: BTreeMap<Symbol, ClassMethodInfo>,
306    /// Names introduced by `declare fn` (forward declarations).
307    ///
308    /// These are placeholders in the type environment and must not block a later
309    /// real definition (e.g. `fn foo = ...` or host/CLI injection).
310    pub declared_values: BTreeSet<Symbol>,
311    pub supply: TypeVarSupply,
312    pub limits: TypeSystemLimits,
313}
314
315/// Semantic information about a type class declaration, derived from Rex source.
316///
317/// Design notes (WARM):
318/// - We keep this explicit and data-oriented: it makes review easy and keeps costs visible.
319/// - Rex represents multi-parameter classes by encoding the parameters as a tuple in the
320///   single `Predicate.typ` slot. For a unary class `C a` the predicate is `C a`. For a
321///   binary class `C t a` the predicate is `C (t, a)`, etc.
322/// - This keeps the runtime/type-inference machinery simple: instance matching is still
323///   “unify the predicate types”, and no separate arity tracking is needed.
324#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
325pub struct ClassInfo {
326    pub name: Symbol,
327    pub params: Vec<Symbol>,
328    pub supers: Vec<Symbol>,
329    pub methods: BTreeMap<Symbol, Scheme>,
330}
331
332#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
333pub struct ClassMethodInfo {
334    pub class: Symbol,
335    pub scheme: Scheme,
336}
337
338#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
339pub struct PreparedInstanceDecl {
340    pub span: Span,
341    pub class: Symbol,
342    pub head: Type,
343    pub context: Vec<Predicate>,
344}
345
346impl TypeSystem {
347    pub fn new() -> Self {
348        Self {
349            env: TypeEnv::new(),
350            classes: ClassEnv::new(),
351            adts: BTreeMap::new(),
352            class_info: BTreeMap::new(),
353            class_methods: BTreeMap::new(),
354            declared_values: BTreeSet::new(),
355            supply: TypeVarSupply::new(),
356            limits: TypeSystemLimits::default(),
357        }
358    }
359
360    pub fn fresh_type_var(&mut self, name: Option<Symbol>) -> TypeVar {
361        self.supply.fresh(name)
362    }
363
364    pub fn set_limits(&mut self, limits: TypeSystemLimits) {
365        self.limits = limits;
366    }
367
368    pub fn new_with_prelude() -> Result<Self, TypeError> {
369        let mut ts = TypeSystem::new();
370        prelude::build_prelude(&mut ts)?;
371        Ok(ts)
372    }
373
374    fn register_decl(&mut self, decl: &Decl) -> Result<(), TypeError> {
375        match decl {
376            Decl::Type(ty) => self.register_type_decl(ty),
377            Decl::Class(class_decl) => self.register_class_decl(class_decl),
378            Decl::Instance(inst_decl) => {
379                let _ = self.register_instance_decl(inst_decl)?;
380                Ok(())
381            }
382            Decl::Fn(fd) => self.register_fn_decls(std::slice::from_ref(fd)),
383            Decl::DeclareFn(fd) => self.inject_declare_fn_decl(fd),
384            Decl::Import(..) => Ok(()),
385        }
386    }
387
388    pub fn register_decls(&mut self, decls: &[Decl]) -> Result<(), TypeError> {
389        let mut pending_fns: Vec<FnDecl> = Vec::new();
390        for decl in decls {
391            if let Decl::Fn(fd) = decl {
392                pending_fns.push(fd.clone());
393                continue;
394            }
395
396            if !pending_fns.is_empty() {
397                self.register_fn_decls(&pending_fns)?;
398                pending_fns.clear();
399            }
400
401            self.register_decl(decl)?;
402        }
403        if !pending_fns.is_empty() {
404            self.register_fn_decls(&pending_fns)?;
405        }
406        Ok(())
407    }
408
409    pub fn add_value(&mut self, name: impl AsRef<str>, scheme: Scheme) {
410        let name = sym(name.as_ref());
411        self.declared_values.remove(&name);
412        self.env.extend(name, scheme);
413    }
414
415    pub fn add_overload(&mut self, name: impl AsRef<str>, scheme: Scheme) {
416        let name = sym(name.as_ref());
417        self.declared_values.remove(&name);
418        self.env.extend_overload(name, scheme);
419    }
420
421    pub fn register_instance(&mut self, class: impl AsRef<str>, inst: Instance) {
422        self.classes.add_instance(sym(class.as_ref()), inst);
423    }
424
425    pub fn register_class_decl(&mut self, decl: &ClassDecl) -> Result<(), TypeError> {
426        let span = decl.span;
427        (|| {
428            // Classes are global, and Rex does not support reopening/merging them.
429            // Allowing that would be a long-term maintenance hazard: it creates
430            // spooky-action-at-a-distance across modules and makes reviews harder.
431            if self.class_info.contains_key(&decl.name)
432                || self.classes.classes.contains_key(&decl.name)
433            {
434                return Err(TypeError::DuplicateClass(decl.name.clone()));
435            }
436            if decl.params.is_empty() {
437                return Err(TypeError::InvalidClassArity {
438                    class: decl.name.clone(),
439                    got: decl.params.len(),
440                });
441            }
442            let params = decl.params.clone();
443
444            // Register the superclass relationships in the class environment.
445            //
446            // We only accept `<= C param` style superclasses for now. Anything
447            // fancier would require storing type-level relationships in `ClassEnv`,
448            // which Rex does not currently model.
449            let mut supers = Vec::with_capacity(decl.supers.len());
450            if !decl.supers.is_empty() && params.len() != 1 {
451                return Err(TypeError::UnsupportedExpr(
452                    "multi-parameter classes cannot declare superclasses yet",
453                ));
454            }
455            for sup in &decl.supers {
456                let mut vars = BTreeMap::new();
457                let param = params[0].clone();
458                let param_tv = self.supply.fresh(Some(param.clone()));
459                vars.insert(param, param_tv.clone());
460                let sup_ty = type_from_annotation_expr_vars(
461                    &self.adts,
462                    &sup.typ,
463                    &mut vars,
464                    &mut self.supply,
465                )?;
466                if sup_ty != Type::var(param_tv) {
467                    return Err(TypeError::UnsupportedExpr(
468                        "superclass constraints must be of the form `<= C a`",
469                    ));
470                }
471                supers.push(sup.class.to_dotted_symbol());
472            }
473
474            self.classes.add_class(decl.name.clone(), supers.clone());
475
476            let mut methods = BTreeMap::new();
477            for ClassMethodSig { name, typ } in &decl.methods {
478                if self.env.lookup(name).is_some() || self.class_methods.contains_key(name) {
479                    return Err(TypeError::DuplicateClassMethod(name.clone()));
480                }
481
482                let mut vars: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
483                let mut param_tvs: Vec<TypeVar> = Vec::with_capacity(params.len());
484                for param in &params {
485                    let tv = self.supply.fresh(Some(param.clone()));
486                    vars.insert(param.clone(), tv.clone());
487                    param_tvs.push(tv);
488                }
489
490                let ty =
491                    type_from_annotation_expr_vars(&self.adts, typ, &mut vars, &mut self.supply)?;
492
493                let mut scheme_vars: Vec<TypeVar> = vars.values().cloned().collect();
494                scheme_vars.sort_by_key(|tv| tv.id);
495                scheme_vars.dedup_by_key(|tv| tv.id);
496
497                let class_pred = Predicate {
498                    class: decl.name.clone(),
499                    typ: if param_tvs.len() == 1 {
500                        Type::var(param_tvs[0].clone())
501                    } else {
502                        Type::tuple(param_tvs.into_iter().map(Type::var).collect())
503                    },
504                };
505                let scheme = Scheme::new(scheme_vars, vec![class_pred], ty);
506
507                self.env.extend(name.clone(), scheme.clone());
508                self.class_methods.insert(
509                    name.clone(),
510                    ClassMethodInfo {
511                        class: decl.name.clone(),
512                        scheme: scheme.clone(),
513                    },
514                );
515                methods.insert(name.clone(), scheme);
516            }
517
518            self.class_info.insert(
519                decl.name.clone(),
520                ClassInfo {
521                    name: decl.name.clone(),
522                    params,
523                    supers,
524                    methods,
525                },
526            );
527            Ok(())
528        })()
529        .map_err(|err| err.with_span(&span))
530    }
531
532    pub fn register_instance_decl(
533        &mut self,
534        decl: &InstanceDecl,
535    ) -> Result<PreparedInstanceDecl, TypeError> {
536        let span = decl.span;
537        (|| {
538            let class = decl.class.clone();
539            if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
540                return Err(TypeError::UnknownClass(class));
541            }
542
543            let mut vars: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
544            let head = type_from_annotation_expr_vars(
545                &self.adts,
546                &decl.head,
547                &mut vars,
548                &mut self.supply,
549            )?;
550            let context = predicates_from_constraints(
551                &self.adts,
552                &decl.context,
553                &mut vars,
554                &mut self.supply,
555            )?;
556
557            let inst = Instance::new(
558                context.clone(),
559                Predicate {
560                    class: decl.class.clone(),
561                    typ: head.clone(),
562                },
563            );
564
565            // Validate method list against the class declaration if present.
566            if let Some(info) = self.class_info.get(&decl.class) {
567                for method in &decl.methods {
568                    if !info.methods.contains_key(&method.name) {
569                        return Err(TypeError::UnknownInstanceMethod {
570                            class: decl.class.clone(),
571                            method: method.name.clone(),
572                        });
573                    }
574                }
575                for method_name in info.methods.keys() {
576                    if !decl.methods.iter().any(|m| &m.name == method_name) {
577                        return Err(TypeError::MissingInstanceMethod {
578                            class: decl.class.clone(),
579                            method: method_name.clone(),
580                        });
581                    }
582                }
583            }
584
585            self.classes.add_instance(decl.class.clone(), inst);
586            Ok(PreparedInstanceDecl {
587                span,
588                class: decl.class.clone(),
589                head,
590                context,
591            })
592        })()
593        .map_err(|err| err.with_span(&span))
594    }
595
596    pub fn prepare_instance_decl(
597        &mut self,
598        decl: &InstanceDecl,
599    ) -> Result<PreparedInstanceDecl, TypeError> {
600        let span = decl.span;
601        (|| {
602            let class = decl.class.clone();
603            if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
604                return Err(TypeError::UnknownClass(class));
605            }
606
607            let mut vars: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
608            let head = type_from_annotation_expr_vars(
609                &self.adts,
610                &decl.head,
611                &mut vars,
612                &mut self.supply,
613            )?;
614            let context = predicates_from_constraints(
615                &self.adts,
616                &decl.context,
617                &mut vars,
618                &mut self.supply,
619            )?;
620
621            // Validate method list against the class declaration if present.
622            if let Some(info) = self.class_info.get(&decl.class) {
623                for method in &decl.methods {
624                    if !info.methods.contains_key(&method.name) {
625                        return Err(TypeError::UnknownInstanceMethod {
626                            class: decl.class.clone(),
627                            method: method.name.clone(),
628                        });
629                    }
630                }
631                for method_name in info.methods.keys() {
632                    if !decl.methods.iter().any(|m| &m.name == method_name) {
633                        return Err(TypeError::MissingInstanceMethod {
634                            class: decl.class.clone(),
635                            method: method_name.clone(),
636                        });
637                    }
638                }
639            }
640
641            Ok(PreparedInstanceDecl {
642                span,
643                class: decl.class.clone(),
644                head,
645                context,
646            })
647        })()
648        .map_err(|err| err.with_span(&span))
649    }
650
651    pub fn register_fn_decls(&mut self, decls: &[FnDecl]) -> Result<(), TypeError> {
652        if decls.is_empty() {
653            return Ok(());
654        }
655
656        let saved_env = self.env.clone();
657        let saved_declared = self.declared_values.clone();
658
659        let result: Result<(), TypeError> = (|| {
660            #[derive(Clone)]
661            struct FnInfo {
662                decl: FnDecl,
663                expected: Type,
664                declared_preds: Vec<Predicate>,
665                scheme: Scheme,
666                ann_vars: BTreeMap<Symbol, TypeVar>,
667            }
668
669            let mut infos: Vec<FnInfo> = Vec::with_capacity(decls.len());
670            let mut seen_names = BTreeSet::new();
671
672            for decl in decls {
673                let span = decl.span;
674                let info = (|| {
675                    let name = &decl.name.name;
676                    if !seen_names.insert(name.clone()) {
677                        return Err(TypeError::DuplicateValue(name.clone()));
678                    }
679
680                    if self.env.lookup(name).is_some() {
681                        if self.declared_values.remove(name) {
682                            // A forward declaration should not block the real definition.
683                            self.env.remove(name);
684                        } else {
685                            return Err(TypeError::DuplicateValue(name.clone()));
686                        }
687                    }
688
689                    let mut sig = decl.ret.clone();
690                    for (_, ann) in decl.params.iter().rev() {
691                        let span = Span::from_begin_end(ann.span().begin, sig.span().end);
692                        sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
693                    }
694
695                    let mut ann_vars: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
696                    let expected = type_from_annotation_expr_vars(
697                        &self.adts,
698                        &sig,
699                        &mut ann_vars,
700                        &mut self.supply,
701                    )?;
702                    let declared_preds = predicates_from_constraints(
703                        &self.adts,
704                        &decl.constraints,
705                        &mut ann_vars,
706                        &mut self.supply,
707                    )?;
708
709                    // Validate that declared constraints are well-formed.
710                    let var_arities: BTreeMap<TypeVarId, usize> = ann_vars
711                        .values()
712                        .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
713                        .collect();
714                    for pred in &declared_preds {
715                        let _ = entails(&self.classes, &[], pred)?;
716                        let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
717                        else {
718                            continue;
719                        };
720                        let args: Vec<Type> = if expected_arities.len() == 1 {
721                            vec![pred.typ.clone()]
722                        } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
723                            if parts.len() != expected_arities.len() {
724                                continue;
725                            }
726                            parts.clone()
727                        } else {
728                            continue;
729                        };
730
731                        for (arg, expected_arity) in
732                            args.iter().zip(expected_arities.iter().copied())
733                        {
734                            let got =
735                                type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
736                                    TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
737                                    _ => None,
738                                });
739                            let Some(got) = got else {
740                                continue;
741                            };
742                            if got != expected_arity {
743                                return Err(TypeError::KindMismatch {
744                                    class: pred.class.clone(),
745                                    expected: expected_arity,
746                                    got,
747                                    typ: arg.to_string(),
748                                });
749                            }
750                        }
751                    }
752
753                    let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
754                    vars.sort_by_key(|v| v.id);
755                    let scheme = Scheme::new(vars, declared_preds.clone(), expected.clone());
756                    reject_ambiguous_scheme(&scheme)?;
757
758                    Ok(FnInfo {
759                        decl: decl.clone(),
760                        expected,
761                        declared_preds,
762                        scheme,
763                        ann_vars,
764                    })
765                })();
766
767                infos.push(info.map_err(|err| err.with_span(&span))?);
768            }
769
770            // Seed environment with all declared signatures first so fn bodies
771            // can reference each other recursively (let-rec semantics).
772            for info in &infos {
773                self.env
774                    .extend(info.decl.name.name.clone(), info.scheme.clone());
775            }
776
777            for info in infos {
778                let span = info.decl.span;
779                let mut lam_body = info.decl.body.clone();
780                let mut lam_end = lam_body.span().end;
781                for (param, ann) in info.decl.params.iter().rev() {
782                    let lam_constraints = Vec::new();
783                    let span = Span::from_begin_end(param.span.begin, lam_end);
784                    lam_body = Arc::new(Expr::Lam(
785                        span,
786                        Scope::new_sync(),
787                        param.clone(),
788                        Some(ann.clone()),
789                        lam_constraints,
790                        lam_body,
791                    ));
792                    lam_end = lam_body.span().end;
793                }
794
795                let (typed, preds, inferred) = infer_typed(self, lam_body.as_ref())?;
796                let s = unify(&inferred, &info.expected)?;
797                let preds = preds.apply(&s);
798                let inferred = inferred.apply(&s);
799                let declared_preds = info.declared_preds.apply(&s);
800                let expected = info.expected.apply(&s);
801
802                // Keep kind checks aligned with existing `inject_fn_decl` logic.
803                let var_arities: BTreeMap<TypeVarId, usize> = info
804                    .ann_vars
805                    .values()
806                    .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
807                    .collect();
808                for pred in &declared_preds {
809                    let _ = entails(&self.classes, &[], pred)?;
810                    let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
811                    else {
812                        continue;
813                    };
814                    let args: Vec<Type> = if expected_arities.len() == 1 {
815                        vec![pred.typ.clone()]
816                    } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
817                        if parts.len() != expected_arities.len() {
818                            continue;
819                        }
820                        parts.clone()
821                    } else {
822                        continue;
823                    };
824
825                    for (arg, expected_arity) in args.iter().zip(expected_arities.iter().copied()) {
826                        let got = type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
827                            TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
828                            _ => None,
829                        });
830                        let Some(got) = got else {
831                            continue;
832                        };
833                        if got != expected_arity {
834                            let err = TypeError::KindMismatch {
835                                class: pred.class.clone(),
836                                expected: expected_arity,
837                                got,
838                                typ: arg.to_string(),
839                            };
840                            return Err(err.with_span(&span));
841                        }
842                    }
843                }
844
845                check_non_ground_predicates_declared(&self.classes, &declared_preds, &preds)
846                    .map_err(|err| err.with_span(&span))?;
847
848                let _ = inferred;
849                let _ = typed;
850            }
851
852            Ok(())
853        })();
854
855        if result.is_err() {
856            self.env = saved_env;
857            self.declared_values = saved_declared;
858        }
859        result
860    }
861
862    pub fn inject_declare_fn_decl(&mut self, decl: &DeclareFnDecl) -> Result<(), TypeError> {
863        let span = decl.span;
864        (|| {
865            // Build the declared signature type.
866            let mut sig = decl.ret.clone();
867            for (_, ann) in decl.params.iter().rev() {
868                let span = Span::from_begin_end(ann.span().begin, sig.span().end);
869                sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
870            }
871
872            let mut ann_vars: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
873            let expected =
874                type_from_annotation_expr_vars(&self.adts, &sig, &mut ann_vars, &mut self.supply)?;
875            let declared_preds = predicates_from_constraints(
876                &self.adts,
877                &decl.constraints,
878                &mut ann_vars,
879                &mut self.supply,
880            )?;
881
882            let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
883            vars.sort_by_key(|v| v.id);
884            let scheme = Scheme::new(vars, declared_preds, expected);
885            reject_ambiguous_scheme(&scheme)?;
886
887            // Validate referenced classes exist (and are spelled correctly).
888            for pred in &scheme.preds {
889                let _ = entails(&self.classes, &[], pred)?;
890            }
891
892            let name = &decl.name.name;
893
894            // If there is already a real definition (prelude/host/`fn`), treat
895            // `declare fn` as documentation only and ignore it.
896            if self.env.lookup(name).is_some() && !self.declared_values.contains(name) {
897                return Ok(());
898            }
899
900            if let Some(existing) = self.env.lookup(name) {
901                if existing.iter().any(|s| scheme_compatible(s, &scheme)) {
902                    return Ok(());
903                }
904                return Err(TypeError::DuplicateValue(decl.name.name.clone()));
905            }
906
907            self.env.extend(decl.name.name.clone(), scheme);
908            self.declared_values.insert(decl.name.name.clone());
909            Ok(())
910        })()
911        .map_err(|err| err.with_span(&span))
912    }
913
914    pub fn instantiate_class_method_for_head(
915        &mut self,
916        class: &Symbol,
917        method: &Symbol,
918        head: &Type,
919    ) -> Result<Type, TypeError> {
920        let info = self
921            .class_info
922            .get(class)
923            .ok_or_else(|| TypeError::UnknownClass(class.clone()))?;
924        let scheme = info
925            .methods
926            .get(method)
927            .ok_or_else(|| TypeError::UnknownInstanceMethod {
928                class: class.clone(),
929                method: method.clone(),
930            })?;
931
932        let (preds, typ) = instantiate(scheme, &mut self.supply);
933        let class_pred =
934            preds
935                .iter()
936                .find(|p| &p.class == class)
937                .ok_or(TypeError::UnsupportedExpr(
938                    "class method scheme missing class predicate",
939                ))?;
940        let s = unify(&class_pred.typ, head)?;
941        Ok(typ.apply(&s))
942    }
943
944    pub fn typecheck_instance_method(
945        &mut self,
946        prepared: &PreparedInstanceDecl,
947        method: &InstanceMethodImpl,
948    ) -> Result<TypedExpr, TypeError> {
949        let expected =
950            self.instantiate_class_method_for_head(&prepared.class, &method.name, &prepared.head)?;
951        let (typed, preds, actual) = infer_typed(self, method.body.as_ref())?;
952        let s = unify(&actual, &expected)?;
953        let typed = typed.apply(&s);
954        let preds = preds.apply(&s);
955
956        // The only legal “given” constraints inside an instance method are the
957        // instance context (plus superclass closure, plus the instance head
958        // itself). We do *not* allow instance
959        // search for non-ground constraints here, because that would be unsound:
960        // a type variable would unify with any concrete instance head.
961        let mut given = prepared.context.clone();
962
963        // Allow recursive instance methods (e.g. `Eq (List a)` calling `(==)`
964        // on the tail). This is dictionary recursion, not instance search.
965        given.push(Predicate::new(
966            prepared.class.clone(),
967            prepared.head.clone(),
968        ));
969        let mut i = 0;
970        while i < given.len() {
971            let p = given[i].clone();
972            for sup in self.classes.supers_of(&p.class) {
973                given.push(Predicate::new(sup, p.typ.clone()));
974            }
975            i += 1;
976        }
977
978        for pred in &preds {
979            if pred.typ.ftv().is_empty() {
980                if !entails(&self.classes, &given, pred)? {
981                    return Err(TypeError::NoInstance(
982                        pred.class.clone(),
983                        pred.typ.to_string(),
984                    ));
985                }
986            } else if !given
987                .iter()
988                .any(|p| p.class == pred.class && p.typ == pred.typ)
989            {
990                return Err(TypeError::MissingInstanceConstraint {
991                    method: method.name.clone(),
992                    class: pred.class.clone(),
993                    typ: pred.typ.to_string(),
994                });
995            }
996        }
997
998        Ok(typed)
999    }
1000
1001    /// Register constructor schemes for an ADT in the type environment.
1002    /// This makes constructors (e.g. `Some`, `None`, `Ok`, `Err`) available
1003    /// to the type checker as normal values.
1004    pub fn register_adt(&mut self, adt: &AdtDecl) {
1005        self.adts.insert(adt.name.clone(), adt.clone());
1006        for (name, scheme) in adt.constructor_schemes() {
1007            self.register_value_scheme(&name, scheme);
1008        }
1009    }
1010
1011    pub fn adt_from_decl(&mut self, decl: &TypeDecl) -> Result<AdtDecl, TypeError> {
1012        let mut adt = AdtDecl::new(&decl.name, &decl.params, &mut self.supply);
1013        let mut param_map: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
1014        for param in &adt.params {
1015            param_map.insert(param.name.clone(), param.var.clone());
1016        }
1017
1018        for variant in &decl.variants {
1019            let mut args = Vec::new();
1020            for arg in &variant.args {
1021                let ty = self.type_from_expr(decl, &param_map, arg)?;
1022                args.push(ty);
1023            }
1024            adt.add_variant(variant.name.clone(), args);
1025        }
1026        Ok(adt)
1027    }
1028
1029    pub fn register_type_decl(&mut self, decl: &TypeDecl) -> Result<(), TypeError> {
1030        if BuiltinTypeId::from_symbol(&decl.name).is_some() {
1031            return Err(TypeError::ReservedTypeName(decl.name.clone()));
1032        }
1033        let adt = self.adt_from_decl(decl)?;
1034        self.register_adt(&adt);
1035        Ok(())
1036    }
1037
1038    fn type_from_expr(
1039        &mut self,
1040        decl: &TypeDecl,
1041        params: &BTreeMap<Symbol, TypeVar>,
1042        expr: &TypeExpr,
1043    ) -> Result<Type, TypeError> {
1044        let span = *expr.span();
1045        let res = (|| match expr {
1046            TypeExpr::Name(_, name) => {
1047                let name_sym = name.to_dotted_symbol();
1048                if let Some(tv) = params.get(&name_sym) {
1049                    Ok(Type::var(tv.clone()))
1050                } else {
1051                    let name = normalize_type_name(&name_sym);
1052                    if let Some(arity) = self.type_arity(decl, &name) {
1053                        Ok(Type::con(name, arity))
1054                    } else {
1055                        Err(TypeError::UnknownTypeName(name))
1056                    }
1057                }
1058            }
1059            TypeExpr::App(_, fun, arg) => {
1060                let fty = self.type_from_expr(decl, params, fun)?;
1061                let aty = self.type_from_expr(decl, params, arg)?;
1062                Ok(type_app_with_result_syntax(fty, aty))
1063            }
1064            TypeExpr::Fun(_, arg, ret) => {
1065                let arg_ty = self.type_from_expr(decl, params, arg)?;
1066                let ret_ty = self.type_from_expr(decl, params, ret)?;
1067                Ok(Type::fun(arg_ty, ret_ty))
1068            }
1069            TypeExpr::Tuple(_, elems) => {
1070                let mut out = Vec::new();
1071                for elem in elems {
1072                    out.push(self.type_from_expr(decl, params, elem)?);
1073                }
1074                Ok(Type::tuple(out))
1075            }
1076            TypeExpr::Record(_, fields) => {
1077                let mut out = Vec::new();
1078                for (name, ty) in fields {
1079                    out.push((name.clone(), self.type_from_expr(decl, params, ty)?));
1080                }
1081                Ok(Type::record(out))
1082            }
1083        })();
1084        res.map_err(|err| err.with_span(&span))
1085    }
1086
1087    fn type_arity(&self, decl: &TypeDecl, name: &Symbol) -> Option<usize> {
1088        if &decl.name == name {
1089            return Some(decl.params.len());
1090        }
1091        if let Some(adt) = self.adts.get(name) {
1092            return Some(adt.params.len());
1093        }
1094        BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
1095    }
1096
1097    fn register_value_scheme(&mut self, name: &Symbol, scheme: Scheme) {
1098        match self.env.lookup(name) {
1099            None => self.env.extend(name.clone(), scheme),
1100            Some(existing) => {
1101                if existing.iter().any(|s| unify(&s.typ, &scheme.typ).is_ok()) {
1102                    return;
1103                }
1104                self.env.extend_overload(name.clone(), scheme);
1105            }
1106        }
1107    }
1108
1109    fn expected_class_param_arities(&self, class: &Symbol) -> Option<Vec<usize>> {
1110        let info = self.class_info.get(class)?;
1111        let mut out = vec![0usize; info.params.len()];
1112        for scheme in info.methods.values() {
1113            for (idx, param) in info.params.iter().enumerate() {
1114                let Some(tv) = scheme.vars.iter().find(|v| v.name.as_ref() == Some(param)) else {
1115                    continue;
1116                };
1117                out[idx] = out[idx].max(max_head_app_arity_for_var(&scheme.typ, tv.id));
1118            }
1119        }
1120        Some(out)
1121    }
1122
1123    fn check_predicate_kind(&self, pred: &Predicate) -> Result<(), TypeError> {
1124        let Some(expected) = self.expected_class_param_arities(&pred.class) else {
1125            // Host-injected classes (via Rust API) won't have `class_info`.
1126            return Ok(());
1127        };
1128
1129        let args: Vec<Type> = if expected.len() == 1 {
1130            vec![pred.typ.clone()]
1131        } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
1132            if parts.len() != expected.len() {
1133                return Ok(());
1134            }
1135            parts.clone()
1136        } else {
1137            return Ok(());
1138        };
1139
1140        for (arg, expected_arity) in args.iter().zip(expected.iter().copied()) {
1141            let Some(got) = type_term_remaining_arity(arg) else {
1142                // If we can't determine the arity (e.g. a bare type var), skip:
1143                // call sites may fix it up, and Rex does not currently do full
1144                // kind inference.
1145                continue;
1146            };
1147            if got != expected_arity {
1148                return Err(TypeError::KindMismatch {
1149                    class: pred.class.clone(),
1150                    expected: expected_arity,
1151                    got,
1152                    typ: arg.to_string(),
1153                });
1154            }
1155        }
1156        Ok(())
1157    }
1158
1159    pub(crate) fn check_predicate_kinds(&self, preds: &[Predicate]) -> Result<(), TypeError> {
1160        for pred in preds {
1161            self.check_predicate_kind(pred)?;
1162        }
1163        Ok(())
1164    }
1165}
1166
1167pub(crate) fn type_from_annotation_expr(
1168    adts: &BTreeMap<Symbol, AdtDecl>,
1169    expr: &TypeExpr,
1170) -> Result<Type, TypeError> {
1171    let span = *expr.span();
1172    let res = (|| match expr {
1173        TypeExpr::Name(_, name) => {
1174            let name = normalize_type_name(&name.to_dotted_symbol());
1175            match annotation_type_arity(adts, &name) {
1176                Some(arity) => Ok(Type::con(name, arity)),
1177                None => Err(TypeError::UnknownTypeName(name)),
1178            }
1179        }
1180        TypeExpr::App(_, fun, arg) => {
1181            let fty = type_from_annotation_expr(adts, fun)?;
1182            let aty = type_from_annotation_expr(adts, arg)?;
1183            Ok(type_app_with_result_syntax(fty, aty))
1184        }
1185        TypeExpr::Fun(_, arg, ret) => {
1186            let arg_ty = type_from_annotation_expr(adts, arg)?;
1187            let ret_ty = type_from_annotation_expr(adts, ret)?;
1188            Ok(Type::fun(arg_ty, ret_ty))
1189        }
1190        TypeExpr::Tuple(_, elems) => {
1191            let mut out = Vec::new();
1192            for elem in elems {
1193                out.push(type_from_annotation_expr(adts, elem)?);
1194            }
1195            Ok(Type::tuple(out))
1196        }
1197        TypeExpr::Record(_, fields) => {
1198            let mut out = Vec::new();
1199            for (name, ty) in fields {
1200                out.push((name.clone(), type_from_annotation_expr(adts, ty)?));
1201            }
1202            Ok(Type::record(out))
1203        }
1204    })();
1205    res.map_err(|err| err.with_span(&span))
1206}
1207
1208pub(crate) fn type_from_annotation_expr_vars(
1209    adts: &BTreeMap<Symbol, AdtDecl>,
1210    expr: &TypeExpr,
1211    vars: &mut BTreeMap<Symbol, TypeVar>,
1212    supply: &mut TypeVarSupply,
1213) -> Result<Type, TypeError> {
1214    let span = *expr.span();
1215    let res = (|| match expr {
1216        TypeExpr::Name(_, name) => {
1217            let name = normalize_type_name(&name.to_dotted_symbol());
1218            if let Some(arity) = annotation_type_arity(adts, &name) {
1219                Ok(Type::con(name, arity))
1220            } else if let Some(tv) = vars.get(&name) {
1221                Ok(Type::var(tv.clone()))
1222            } else {
1223                let is_upper = name
1224                    .chars()
1225                    .next()
1226                    .map(|c| c.is_uppercase())
1227                    .unwrap_or(false);
1228                if is_upper {
1229                    return Err(TypeError::UnknownTypeName(name));
1230                }
1231                let tv = supply.fresh(Some(name.clone()));
1232                vars.insert(name.clone(), tv.clone());
1233                Ok(Type::var(tv))
1234            }
1235        }
1236        TypeExpr::App(_, fun, arg) => {
1237            let fty = type_from_annotation_expr_vars(adts, fun, vars, supply)?;
1238            let aty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
1239            Ok(type_app_with_result_syntax(fty, aty))
1240        }
1241        TypeExpr::Fun(_, arg, ret) => {
1242            let arg_ty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
1243            let ret_ty = type_from_annotation_expr_vars(adts, ret, vars, supply)?;
1244            Ok(Type::fun(arg_ty, ret_ty))
1245        }
1246        TypeExpr::Tuple(_, elems) => {
1247            let mut out = Vec::new();
1248            for elem in elems {
1249                out.push(type_from_annotation_expr_vars(adts, elem, vars, supply)?);
1250            }
1251            Ok(Type::tuple(out))
1252        }
1253        TypeExpr::Record(_, fields) => {
1254            let mut out = Vec::new();
1255            for (name, ty) in fields {
1256                out.push((
1257                    name.clone(),
1258                    type_from_annotation_expr_vars(adts, ty, vars, supply)?,
1259                ));
1260            }
1261            Ok(Type::record(out))
1262        }
1263    })();
1264    res.map_err(|err| err.with_span(&span))
1265}
1266
1267fn annotation_type_arity(adts: &BTreeMap<Symbol, AdtDecl>, name: &Symbol) -> Option<usize> {
1268    if let Some(adt) = adts.get(name) {
1269        return Some(adt.params.len());
1270    }
1271    BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
1272}
1273
1274fn normalize_type_name(name: &Symbol) -> Symbol {
1275    if name.as_ref() == "str" {
1276        BuiltinTypeId::String.as_symbol()
1277    } else {
1278        name.clone()
1279    }
1280}
1281
1282fn type_app_with_result_syntax(fun: Type, arg: Type) -> Type {
1283    if let TypeKind::App(head, ok) = fun.as_ref()
1284        && matches!(
1285            head.as_ref(),
1286            TypeKind::Con(c)
1287                if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
1288        )
1289    {
1290        return Type::app(Type::app(head.clone(), arg), ok.clone());
1291    }
1292    Type::app(fun, arg)
1293}
1294
1295pub(crate) fn predicates_from_constraints(
1296    adts: &BTreeMap<Symbol, AdtDecl>,
1297    constraints: &[TypeConstraint],
1298    vars: &mut BTreeMap<Symbol, TypeVar>,
1299    supply: &mut TypeVarSupply,
1300) -> Result<Vec<Predicate>, TypeError> {
1301    let mut out = Vec::with_capacity(constraints.len());
1302    for constraint in constraints {
1303        let ty = type_from_annotation_expr_vars(adts, &constraint.typ, vars, supply)?;
1304        out.push(Predicate::new(constraint.class.as_ref(), ty));
1305    }
1306    Ok(out)
1307}