Skip to main content

rex_typesystem/
typesystem.rs

1//! Core type system implementation for Rex.
2
3use std::collections::{BTreeMap, BTreeSet};
4use std::sync::Arc;
5
6use rex_ast::{
7    ClassDecl, ClassMethodSig, Decl, DeclareFnDecl, Expr, FnDecl, InstanceDecl, InstanceMethodImpl,
8    Scope, Span, Symbol, TypeConstraint, TypeDecl, TypeExpr,
9};
10
11use crate::prelude;
12
13pub use crate::unification::{Subst, compose_subst, unify};
14
15use crate::{
16    error::TypeError,
17    inference::infer_typed,
18    types::{
19        AdtDecl, BuiltinTypeId, ClassEnv, Instance, Predicate, Scheme, Type, TypeEnv, TypeKind,
20        TypeVar, TypeVarId, TypedExpr, Types,
21    },
22    unification::scheme_compatible,
23};
24
25fn format_constraints_referencing_vars(preds: &[Predicate], vars: &[TypeVarId]) -> String {
26    if vars.is_empty() {
27        return String::new();
28    }
29    let var_set: BTreeSet<TypeVarId> = vars.iter().copied().collect();
30    let mut parts = Vec::new();
31    for pred in preds {
32        let ftv = pred.ftv();
33        if ftv.iter().any(|v| var_set.contains(v)) {
34            parts.push(format!("{} {}", pred.class, pred.typ));
35        }
36    }
37    if parts.is_empty() {
38        // Fallback: show all constraints if the filtering logic misses something.
39        for pred in preds {
40            parts.push(format!("{} {}", pred.class, pred.typ));
41        }
42    }
43    parts.join(", ")
44}
45
46pub(crate) fn reject_ambiguous_scheme(scheme: &Scheme) -> Result<(), TypeError> {
47    // Only reject *quantified* ambiguous variables. Variables free in the
48    // environment are allowed to appear only in predicates, since they can be
49    // determined by outer context.
50    let quantified: BTreeSet<TypeVarId> = scheme.vars.iter().map(|v| v.id).collect();
51    if quantified.is_empty() {
52        return Ok(());
53    }
54
55    let typ_ftv = scheme.typ.ftv();
56    let mut vars = BTreeSet::new();
57    for pred in &scheme.preds {
58        let TypeKind::Var(tv) = pred.typ.as_ref() else {
59            continue;
60        };
61        if quantified.contains(&tv.id) && !typ_ftv.contains(&tv.id) {
62            vars.insert(tv.id);
63        }
64    }
65
66    if vars.is_empty() {
67        return Ok(());
68    }
69    let mut vars: Vec<TypeVarId> = vars.into_iter().collect();
70    vars.sort_unstable();
71    let constraints = format_constraints_referencing_vars(&scheme.preds, &vars);
72    Err(TypeError::AmbiguousTypeVars { vars, constraints })
73}
74
/// Tunable limits carried by a `TypeSystem`.
///
/// The type is plain `Copy` data and comparable, so callers can check whether
/// a configuration differs from the defaults.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TypeSystemLimits {
    /// Upper bound on inference depth; `None` means no bound.
    ///
    /// NOTE(review): enforcement happens outside this block — presumably in
    /// the inference module; confirm there.
    pub max_infer_depth: Option<usize>,
}

impl TypeSystemLimits {
    /// Depth bound used by [`TypeSystemLimits::safe_defaults`].
    pub const DEFAULT_MAX_INFER_DEPTH: usize = 4096;

    /// Limits with every bound disabled.
    pub fn unlimited() -> Self {
        Self {
            max_infer_depth: None,
        }
    }

    /// Conservative limits; this is also what `Default` returns.
    pub fn safe_defaults() -> Self {
        Self {
            max_infer_depth: Some(Self::DEFAULT_MAX_INFER_DEPTH),
        }
    }
}

impl Default for TypeSystemLimits {
    /// Same as [`TypeSystemLimits::safe_defaults`].
    fn default() -> Self {
        Self::safe_defaults()
    }
}
99
100fn superclass_closure(class_env: &ClassEnv, given: &[Predicate]) -> Vec<Predicate> {
101    let mut closure: Vec<Predicate> = given.to_vec();
102    let mut i = 0;
103    while i < closure.len() {
104        let p = closure[i].clone();
105        for sup in class_env.supers_of(&p.class) {
106            closure.push(Predicate::new(sup, p.typ.clone()));
107        }
108        i += 1;
109    }
110    closure
111}
112
113fn check_non_ground_predicates_declared(
114    class_env: &ClassEnv,
115    declared: &[Predicate],
116    inferred: &[Predicate],
117) -> Result<(), TypeError> {
118    // Compare by a stable, user-facing rendering (`Default a`, `Foldable t`, ...),
119    // rather than `TypeVarId`, so signature variables that only appear in
120    // predicates (and thus aren't related by unification) still match up.
121    let closure = superclass_closure(class_env, declared);
122    let closure_keys: BTreeSet<String> = closure
123        .iter()
124        .map(|p| format!("{} {}", p.class, p.typ))
125        .collect();
126    let mut missing = Vec::new();
127    for pred in inferred {
128        if pred.typ.ftv().is_empty() {
129            continue;
130        }
131        let key = format!("{} {}", pred.class, pred.typ);
132        if !closure_keys.contains(&key) {
133            missing.push(key);
134        }
135    }
136
137    missing.sort();
138    missing.dedup();
139    if missing.is_empty() {
140        return Ok(());
141    }
142    Err(TypeError::MissingConstraints {
143        constraints: missing.join(", "),
144    })
145}
146
147fn type_term_remaining_arity(ty: &Type) -> Option<usize> {
148    match ty.as_ref() {
149        TypeKind::Var(_) => None,
150        TypeKind::Con(tc) => Some(tc.arity()),
151        TypeKind::App(l, _) => {
152            let a = type_term_remaining_arity(l)?;
153            Some(a.saturating_sub(1))
154        }
155        TypeKind::Fun(..) | TypeKind::Tuple(..) | TypeKind::Record(..) => Some(0),
156    }
157}
158
159fn max_head_app_arity_for_var(ty: &Type, var_id: TypeVarId) -> usize {
160    let mut max_arity = 0usize;
161    let mut stack: Vec<&Type> = vec![ty];
162    while let Some(t) = stack.pop() {
163        match t.as_ref() {
164            TypeKind::Var(_) | TypeKind::Con(_) => {}
165            TypeKind::App(l, r) => {
166                // Record the full application depth at this node.
167                let mut head = t;
168                let mut args = 0usize;
169                while let TypeKind::App(left, _) = head.as_ref() {
170                    args += 1;
171                    head = left;
172                }
173                if let TypeKind::Var(tv) = head.as_ref()
174                    && tv.id == var_id
175                {
176                    max_arity = max_arity.max(args);
177                }
178                stack.push(l);
179                stack.push(r);
180            }
181            TypeKind::Fun(a, b) => {
182                stack.push(a);
183                stack.push(b);
184            }
185            TypeKind::Tuple(ts) => {
186                for t in ts {
187                    stack.push(t);
188                }
189            }
190            TypeKind::Record(fields) => {
191                for (_, t) in fields {
192                    stack.push(t);
193                }
194            }
195        }
196    }
197    max_arity
198}
199
200#[derive(Default, Debug, Clone)]
201pub struct TypeVarSupply {
202    counter: TypeVarId,
203}
204
205impl TypeVarSupply {
206    pub fn new() -> Self {
207        Self { counter: 0 }
208    }
209
210    pub fn fresh(&mut self, name_hint: impl Into<Option<Symbol>>) -> TypeVar {
211        let tv = TypeVar::new(self.counter, name_hint.into());
212        self.counter += 1;
213        tv
214    }
215}
216
217pub(crate) fn is_integral_literal_expr(expr: &Expr) -> bool {
218    matches!(expr, Expr::Int(..) | Expr::Uint(..))
219}
220
221/// Turn a monotype `typ` (plus constraints `preds`) into a polymorphic `Scheme`
222/// by quantifying over the type variables not free in `env`.
223pub fn generalize(env: &TypeEnv, preds: Vec<Predicate>, typ: Type) -> Scheme {
224    let mut vars: Vec<TypeVar> = typ
225        .ftv()
226        .union(&preds.ftv())
227        .copied()
228        .collect::<BTreeSet<_>>()
229        .difference(&env.ftv())
230        .cloned()
231        .map(|id| TypeVar::new(id, None))
232        .collect();
233    vars.sort_by_key(|v| v.id);
234    Scheme::new(vars, preds, typ)
235}
236
237pub fn instantiate(scheme: &Scheme, supply: &mut TypeVarSupply) -> (Vec<Predicate>, Type) {
238    // Instantiate replaces all quantified variables with fresh unification
239    // variables, preserving the original name as a debugging hint.
240    let mut subst = Subst::new_sync();
241    for v in &scheme.vars {
242        subst = subst.insert(v.id, Type::var(supply.fresh(v.name.clone())));
243    }
244    (scheme.preds.apply(&subst), scheme.typ.apply(&subst))
245}
246
247pub fn entails(
248    class_env: &ClassEnv,
249    given: &[Predicate],
250    pred: &Predicate,
251) -> Result<bool, TypeError> {
252    // Expand given with superclasses.
253    let mut closure: Vec<Predicate> = given.to_vec();
254    let mut i = 0;
255    while i < closure.len() {
256        let p = closure[i].clone();
257        for sup in class_env.supers_of(&p.class) {
258            closure.push(Predicate::new(sup, p.typ.clone()));
259        }
260        i += 1;
261    }
262
263    if closure
264        .iter()
265        .any(|p| p.class == pred.class && p.typ == pred.typ)
266    {
267        return Ok(true);
268    }
269
270    if !class_env.classes.contains_key(&pred.class) {
271        return Err(TypeError::UnknownClass(pred.class.clone()));
272    }
273
274    if let Some(instances) = class_env.instances.get(&pred.class) {
275        for inst in instances {
276            if let Ok(s) = unify(&inst.head.typ, &pred.typ) {
277                let ctx = inst.context.apply(&s);
278                if ctx
279                    .iter()
280                    .all(|c| entails(class_env, &closure, c).unwrap_or(false))
281                {
282                    return Ok(true);
283                }
284            }
285        }
286    }
287    Ok(false)
288}
289
/// Mutable state of the Rex type checker: value environment, class and
/// instance tables, ADT declarations, and the fresh-variable supply.
#[derive(Default, Debug, Clone)]
pub struct TypeSystem {
    /// Value environment mapping names to schemes (see `add_value`,
    /// `add_overload`, and the `register_*` methods).
    pub env: TypeEnv,
    /// Class environment: known classes, their superclasses, and instances.
    pub classes: ClassEnv,
    /// ADT declarations by name; consulted when resolving type annotations.
    pub adts: BTreeMap<Symbol, AdtDecl>,
    /// Per-class metadata recorded by `register_class_decl`.
    pub class_info: BTreeMap<Symbol, ClassInfo>,
    /// Class methods by method name, each with its owning class and scheme.
    pub class_methods: BTreeMap<Symbol, ClassMethodInfo>,
    /// Names introduced by `declare fn` (forward declarations).
    ///
    /// These are placeholders in the type environment and must not block a later
    /// real definition (e.g. `fn foo = ...` or host/CLI injection).
    pub declared_values: BTreeSet<Symbol>,
    /// Source of fresh type variables.
    pub supply: TypeVarSupply,
    /// Limits configured for this type system (see `set_limits`).
    pub limits: TypeSystemLimits,
}
305
/// Semantic information about a type class declaration, derived from Rex source.
///
/// Design notes (WARM):
/// - We keep this explicit and data-oriented: it makes review easy and keeps costs visible.
/// - Rex represents multi-parameter classes by encoding the parameters as a tuple in the
///   single `Predicate.typ` slot. For a unary class `C a` the predicate is `C a`. For a
///   binary class `C t a` the predicate is `C (t, a)`, etc.
/// - This keeps the runtime/type-inference machinery simple: instance matching is still
///   “unify the predicate types”, and no separate arity tracking is needed.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct ClassInfo {
    /// Name of the class.
    pub name: Symbol,
    /// Class parameter names, in declaration order (never empty for classes
    /// registered through `register_class_decl`).
    pub params: Vec<Symbol>,
    /// Names of the directly declared superclasses.
    pub supers: Vec<Symbol>,
    /// Method schemes by method name; each scheme carries the class predicate
    /// over the class parameters.
    pub methods: BTreeMap<Symbol, Scheme>,
}
322
/// Entry in `TypeSystem::class_methods`: a class method together with the
/// class that declares it.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct ClassMethodInfo {
    /// Name of the class that declares the method.
    pub class: Symbol,
    /// Scheme of the method, including the class predicate.
    pub scheme: Scheme,
}
328
/// An instance declaration after its head and context have been resolved to
/// types and predicates and its method list has been validated.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct PreparedInstanceDecl {
    /// Source span of the originating `InstanceDecl`.
    pub span: Span,
    /// Class being instantiated.
    pub class: Symbol,
    /// Instance head resolved from the declaration's type expression.
    pub head: Type,
    /// Context predicates resolved from the declaration's constraints.
    pub context: Vec<Predicate>,
}
336
337impl TypeSystem {
338    pub fn new() -> Self {
339        Self {
340            env: TypeEnv::new(),
341            classes: ClassEnv::new(),
342            adts: BTreeMap::new(),
343            class_info: BTreeMap::new(),
344            class_methods: BTreeMap::new(),
345            declared_values: BTreeSet::new(),
346            supply: TypeVarSupply::new(),
347            limits: TypeSystemLimits::default(),
348        }
349    }
350
351    pub fn fresh_type_var(&mut self, name: Option<Symbol>) -> TypeVar {
352        self.supply.fresh(name)
353    }
354
355    pub fn set_limits(&mut self, limits: TypeSystemLimits) {
356        self.limits = limits;
357    }
358
359    pub fn new_with_prelude() -> Result<Self, TypeError> {
360        let mut ts = TypeSystem::new();
361        prelude::build_prelude(&mut ts)?;
362        Ok(ts)
363    }
364
365    fn register_decl(&mut self, decl: &Decl) -> Result<(), TypeError> {
366        match decl {
367            Decl::Type(ty) => self.register_type_decl(ty),
368            Decl::Class(class_decl) => self.register_class_decl(class_decl),
369            Decl::Instance(inst_decl) => {
370                let _ = self.register_instance_decl(inst_decl)?;
371                Ok(())
372            }
373            Decl::Fn(fd) => self.register_fn_decls(std::slice::from_ref(fd)),
374            Decl::DeclareFn(fd) => self.inject_declare_fn_decl(fd),
375            Decl::Import(..) => Ok(()),
376        }
377    }
378
379    pub fn register_decls(&mut self, decls: &[Decl]) -> Result<(), TypeError> {
380        let mut pending_fns: Vec<FnDecl> = Vec::new();
381        for decl in decls {
382            if let Decl::Fn(fd) = decl {
383                pending_fns.push(fd.clone());
384                continue;
385            }
386
387            if !pending_fns.is_empty() {
388                self.register_fn_decls(&pending_fns)?;
389                pending_fns.clear();
390            }
391
392            self.register_decl(decl)?;
393        }
394        if !pending_fns.is_empty() {
395            self.register_fn_decls(&pending_fns)?;
396        }
397        Ok(())
398    }
399
400    pub fn add_value(&mut self, name: impl AsRef<str>, scheme: Scheme) {
401        let name = Symbol::intern(name.as_ref());
402        self.declared_values.remove(&name);
403        self.env.extend(name, scheme);
404    }
405
406    pub fn add_overload(&mut self, name: impl AsRef<str>, scheme: Scheme) {
407        let name = Symbol::intern(name.as_ref());
408        self.declared_values.remove(&name);
409        self.env.extend_overload(name, scheme);
410    }
411
412    pub fn register_instance(&mut self, class: impl AsRef<str>, inst: Instance) {
413        self.classes
414            .add_instance(Symbol::intern(class.as_ref()), inst);
415    }
416
417    pub fn register_class_decl(&mut self, decl: &ClassDecl) -> Result<(), TypeError> {
418        let span = decl.span;
419        (|| {
420            // Classes are global, and Rex does not support reopening/merging them.
421            // Allowing that would be a long-term maintenance hazard: it creates
422            // spooky-action-at-a-distance across modules and makes reviews harder.
423            if self.class_info.contains_key(&decl.name)
424                || self.classes.classes.contains_key(&decl.name)
425            {
426                return Err(TypeError::DuplicateClass(decl.name.clone()));
427            }
428            if decl.params.is_empty() {
429                return Err(TypeError::InvalidClassArity {
430                    class: decl.name.clone(),
431                    got: decl.params.len(),
432                });
433            }
434            let params = decl.params.clone();
435
436            // Register the superclass relationships in the class environment.
437            //
438            // We only accept `<= C param` style superclasses for now. Anything
439            // fancier would require storing type-level relationships in `ClassEnv`,
440            // which Rex does not currently model.
441            let mut supers = Vec::with_capacity(decl.supers.len());
442            if !decl.supers.is_empty() && params.len() != 1 {
443                return Err(TypeError::UnsupportedExpr(
444                    "multi-parameter classes cannot declare superclasses yet",
445                ));
446            }
447            for sup in &decl.supers {
448                let mut vars = BTreeMap::new();
449                let param = params[0].clone();
450                let param_tv = self.supply.fresh(Some(param.clone()));
451                vars.insert(param, param_tv.clone());
452                let sup_ty = type_from_annotation_expr_vars(
453                    &self.adts,
454                    &sup.typ,
455                    &mut vars,
456                    &mut self.supply,
457                )?;
458                if sup_ty != Type::var(param_tv) {
459                    return Err(TypeError::UnsupportedExpr(
460                        "superclass constraints must be of the form `<= C a`",
461                    ));
462                }
463                supers.push(sup.class.to_dotted_symbol());
464            }
465
466            self.classes.add_class(decl.name.clone(), supers.clone());
467
468            let mut methods = BTreeMap::new();
469            for ClassMethodSig { name, typ } in &decl.methods {
470                if self.env.lookup(name).is_some() || self.class_methods.contains_key(name) {
471                    return Err(TypeError::DuplicateClassMethod(name.clone()));
472                }
473
474                let mut vars: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
475                let mut param_tvs: Vec<TypeVar> = Vec::with_capacity(params.len());
476                for param in &params {
477                    let tv = self.supply.fresh(Some(param.clone()));
478                    vars.insert(param.clone(), tv.clone());
479                    param_tvs.push(tv);
480                }
481
482                let ty =
483                    type_from_annotation_expr_vars(&self.adts, typ, &mut vars, &mut self.supply)?;
484
485                let mut scheme_vars: Vec<TypeVar> = vars.values().cloned().collect();
486                scheme_vars.sort_by_key(|tv| tv.id);
487                scheme_vars.dedup_by_key(|tv| tv.id);
488
489                let class_pred = Predicate {
490                    class: decl.name.clone(),
491                    typ: if param_tvs.len() == 1 {
492                        Type::var(param_tvs[0].clone())
493                    } else {
494                        Type::tuple(param_tvs.into_iter().map(Type::var).collect())
495                    },
496                };
497                let scheme = Scheme::new(scheme_vars, vec![class_pred], ty);
498
499                self.env.extend(name.clone(), scheme.clone());
500                self.class_methods.insert(
501                    name.clone(),
502                    ClassMethodInfo {
503                        class: decl.name.clone(),
504                        scheme: scheme.clone(),
505                    },
506                );
507                methods.insert(name.clone(), scheme);
508            }
509
510            self.class_info.insert(
511                decl.name.clone(),
512                ClassInfo {
513                    name: decl.name.clone(),
514                    params,
515                    supers,
516                    methods,
517                },
518            );
519            Ok(())
520        })()
521        .map_err(|err| err.with_span(&span))
522    }
523
524    pub fn register_instance_decl(
525        &mut self,
526        decl: &InstanceDecl,
527    ) -> Result<PreparedInstanceDecl, TypeError> {
528        let span = decl.span;
529        (|| {
530            let class = decl.class.clone();
531            if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
532                return Err(TypeError::UnknownClass(class));
533            }
534
535            let mut vars: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
536            let head = type_from_annotation_expr_vars(
537                &self.adts,
538                &decl.head,
539                &mut vars,
540                &mut self.supply,
541            )?;
542            let context = predicates_from_constraints(
543                &self.adts,
544                &decl.context,
545                &mut vars,
546                &mut self.supply,
547            )?;
548
549            let inst = Instance::new(
550                context.clone(),
551                Predicate {
552                    class: decl.class.clone(),
553                    typ: head.clone(),
554                },
555            );
556
557            // Validate method list against the class declaration if present.
558            if let Some(info) = self.class_info.get(&decl.class) {
559                for method in &decl.methods {
560                    if !info.methods.contains_key(&method.name) {
561                        return Err(TypeError::UnknownInstanceMethod {
562                            class: decl.class.clone(),
563                            method: method.name.clone(),
564                        });
565                    }
566                }
567                for method_name in info.methods.keys() {
568                    if !decl.methods.iter().any(|m| &m.name == method_name) {
569                        return Err(TypeError::MissingInstanceMethod {
570                            class: decl.class.clone(),
571                            method: method_name.clone(),
572                        });
573                    }
574                }
575            }
576
577            self.classes.add_instance(decl.class.clone(), inst);
578            Ok(PreparedInstanceDecl {
579                span,
580                class: decl.class.clone(),
581                head,
582                context,
583            })
584        })()
585        .map_err(|err| err.with_span(&span))
586    }
587
588    pub fn prepare_instance_decl(
589        &mut self,
590        decl: &InstanceDecl,
591    ) -> Result<PreparedInstanceDecl, TypeError> {
592        let span = decl.span;
593        (|| {
594            let class = decl.class.clone();
595            if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
596                return Err(TypeError::UnknownClass(class));
597            }
598
599            let mut vars: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
600            let head = type_from_annotation_expr_vars(
601                &self.adts,
602                &decl.head,
603                &mut vars,
604                &mut self.supply,
605            )?;
606            let context = predicates_from_constraints(
607                &self.adts,
608                &decl.context,
609                &mut vars,
610                &mut self.supply,
611            )?;
612
613            // Validate method list against the class declaration if present.
614            if let Some(info) = self.class_info.get(&decl.class) {
615                for method in &decl.methods {
616                    if !info.methods.contains_key(&method.name) {
617                        return Err(TypeError::UnknownInstanceMethod {
618                            class: decl.class.clone(),
619                            method: method.name.clone(),
620                        });
621                    }
622                }
623                for method_name in info.methods.keys() {
624                    if !decl.methods.iter().any(|m| &m.name == method_name) {
625                        return Err(TypeError::MissingInstanceMethod {
626                            class: decl.class.clone(),
627                            method: method_name.clone(),
628                        });
629                    }
630                }
631            }
632
633            Ok(PreparedInstanceDecl {
634                span,
635                class: decl.class.clone(),
636                head,
637                context,
638            })
639        })()
640        .map_err(|err| err.with_span(&span))
641    }
642
643    pub fn register_fn_decls(&mut self, decls: &[FnDecl]) -> Result<(), TypeError> {
644        if decls.is_empty() {
645            return Ok(());
646        }
647
648        let saved_env = self.env.clone();
649        let saved_declared = self.declared_values.clone();
650
651        let result: Result<(), TypeError> = (|| {
652            #[derive(Clone)]
653            struct FnInfo {
654                decl: FnDecl,
655                expected: Type,
656                declared_preds: Vec<Predicate>,
657                scheme: Scheme,
658                ann_vars: BTreeMap<Symbol, TypeVar>,
659            }
660
661            let mut infos: Vec<FnInfo> = Vec::with_capacity(decls.len());
662            let mut seen_names = BTreeSet::new();
663
664            for decl in decls {
665                let span = decl.span;
666                let info = (|| {
667                    let name = &decl.name.name;
668                    if !seen_names.insert(name.clone()) {
669                        return Err(TypeError::DuplicateValue(name.clone()));
670                    }
671
672                    if self.env.lookup(name).is_some() {
673                        if self.declared_values.remove(name) {
674                            // A forward declaration should not block the real definition.
675                            self.env.remove(name);
676                        } else {
677                            return Err(TypeError::DuplicateValue(name.clone()));
678                        }
679                    }
680
681                    let mut sig = decl.ret.clone();
682                    for (_, ann) in decl.params.iter().rev() {
683                        let span = Span::from_begin_end(ann.span().begin, sig.span().end);
684                        sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
685                    }
686
687                    let mut ann_vars: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
688                    let expected = type_from_annotation_expr_vars(
689                        &self.adts,
690                        &sig,
691                        &mut ann_vars,
692                        &mut self.supply,
693                    )?;
694                    let declared_preds = predicates_from_constraints(
695                        &self.adts,
696                        &decl.constraints,
697                        &mut ann_vars,
698                        &mut self.supply,
699                    )?;
700
701                    // Validate that declared constraints are well-formed.
702                    let var_arities: BTreeMap<TypeVarId, usize> = ann_vars
703                        .values()
704                        .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
705                        .collect();
706                    for pred in &declared_preds {
707                        let _ = entails(&self.classes, &[], pred)?;
708                        let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
709                        else {
710                            continue;
711                        };
712                        let args: Vec<Type> = if expected_arities.len() == 1 {
713                            vec![pred.typ.clone()]
714                        } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
715                            if parts.len() != expected_arities.len() {
716                                continue;
717                            }
718                            parts.clone()
719                        } else {
720                            continue;
721                        };
722
723                        for (arg, expected_arity) in
724                            args.iter().zip(expected_arities.iter().copied())
725                        {
726                            let got =
727                                type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
728                                    TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
729                                    _ => None,
730                                });
731                            let Some(got) = got else {
732                                continue;
733                            };
734                            if got != expected_arity {
735                                return Err(TypeError::KindMismatch {
736                                    class: pred.class.clone(),
737                                    expected: expected_arity,
738                                    got,
739                                    typ: arg.to_string(),
740                                });
741                            }
742                        }
743                    }
744
745                    let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
746                    vars.sort_by_key(|v| v.id);
747                    let scheme = Scheme::new(vars, declared_preds.clone(), expected.clone());
748                    reject_ambiguous_scheme(&scheme)?;
749
750                    Ok(FnInfo {
751                        decl: decl.clone(),
752                        expected,
753                        declared_preds,
754                        scheme,
755                        ann_vars,
756                    })
757                })();
758
759                infos.push(info.map_err(|err| err.with_span(&span))?);
760            }
761
762            // Seed environment with all declared signatures first so fn bodies
763            // can reference each other recursively (let-rec semantics).
764            for info in &infos {
765                self.env
766                    .extend(info.decl.name.name.clone(), info.scheme.clone());
767            }
768
769            for info in infos {
770                let span = info.decl.span;
771                let mut lam_body = info.decl.body.clone();
772                let mut lam_end = lam_body.span().end;
773                for (param, ann) in info.decl.params.iter().rev() {
774                    let lam_constraints = Vec::new();
775                    let span = Span::from_begin_end(param.span.begin, lam_end);
776                    lam_body = Arc::new(Expr::Lam(
777                        span,
778                        Scope::new_sync(),
779                        param.clone(),
780                        Some(ann.clone()),
781                        lam_constraints,
782                        lam_body,
783                    ));
784                    lam_end = lam_body.span().end;
785                }
786
787                // Declaration validation happens before top-level `fn`s are lowered into
788                // synthetic `let`s. If one of these errors escapes without a span, later LSP
789                // fallbacks can attach it to the lowered expression, which may span the whole
790                // snippet. Use the original declaration span as the fallback; `with_span`
791                // preserves any more-specific span already produced inside the body.
792                let (typed, preds, inferred) =
793                    infer_typed(self, lam_body.as_ref()).map_err(|err| err.with_span(&span))?;
794                let s = unify(&inferred, &info.expected).map_err(|err| err.with_span(&span))?;
795                let preds = preds.apply(&s);
796                let inferred = inferred.apply(&s);
797                let declared_preds = info.declared_preds.apply(&s);
798                let expected = info.expected.apply(&s);
799
800                // Keep kind checks aligned with existing `inject_fn_decl` logic.
801                let var_arities: BTreeMap<TypeVarId, usize> = info
802                    .ann_vars
803                    .values()
804                    .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
805                    .collect();
806                for pred in &declared_preds {
807                    let _ =
808                        entails(&self.classes, &[], pred).map_err(|err| err.with_span(&span))?;
809                    let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
810                    else {
811                        continue;
812                    };
813                    let args: Vec<Type> = if expected_arities.len() == 1 {
814                        vec![pred.typ.clone()]
815                    } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
816                        if parts.len() != expected_arities.len() {
817                            continue;
818                        }
819                        parts.clone()
820                    } else {
821                        continue;
822                    };
823
824                    for (arg, expected_arity) in args.iter().zip(expected_arities.iter().copied()) {
825                        let got = type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
826                            TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
827                            _ => None,
828                        });
829                        let Some(got) = got else {
830                            continue;
831                        };
832                        if got != expected_arity {
833                            let err = TypeError::KindMismatch {
834                                class: pred.class.clone(),
835                                expected: expected_arity,
836                                got,
837                                typ: arg.to_string(),
838                            };
839                            return Err(err.with_span(&span));
840                        }
841                    }
842                }
843
844                check_non_ground_predicates_declared(&self.classes, &declared_preds, &preds)
845                    .map_err(|err| err.with_span(&span))?;
846
847                let _ = inferred;
848                let _ = typed;
849            }
850
851            Ok(())
852        })();
853
854        if result.is_err() {
855            self.env = saved_env;
856            self.declared_values = saved_declared;
857        }
858        result
859    }
860
861    pub fn inject_declare_fn_decl(&mut self, decl: &DeclareFnDecl) -> Result<(), TypeError> {
862        let span = decl.span;
863        (|| {
864            // Build the declared signature type.
865            let mut sig = decl.ret.clone();
866            for (_, ann) in decl.params.iter().rev() {
867                let span = Span::from_begin_end(ann.span().begin, sig.span().end);
868                sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
869            }
870
871            let mut ann_vars: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
872            let expected =
873                type_from_annotation_expr_vars(&self.adts, &sig, &mut ann_vars, &mut self.supply)?;
874            let declared_preds = predicates_from_constraints(
875                &self.adts,
876                &decl.constraints,
877                &mut ann_vars,
878                &mut self.supply,
879            )?;
880
881            let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
882            vars.sort_by_key(|v| v.id);
883            let scheme = Scheme::new(vars, declared_preds, expected);
884            reject_ambiguous_scheme(&scheme)?;
885
886            // Validate referenced classes exist (and are spelled correctly).
887            for pred in &scheme.preds {
888                let _ = entails(&self.classes, &[], pred)?;
889            }
890
891            let name = &decl.name.name;
892
893            // If there is already a real definition (prelude/host/`fn`), treat
894            // `declare fn` as documentation only and ignore it.
895            if self.env.lookup(name).is_some() && !self.declared_values.contains(name) {
896                return Ok(());
897            }
898
899            if let Some(existing) = self.env.lookup(name) {
900                if existing.iter().any(|s| scheme_compatible(s, &scheme)) {
901                    return Ok(());
902                }
903                return Err(TypeError::DuplicateValue(decl.name.name.clone()));
904            }
905
906            self.env.extend(decl.name.name.clone(), scheme);
907            self.declared_values.insert(decl.name.name.clone());
908            Ok(())
909        })()
910        .map_err(|err| err.with_span(&span))
911    }
912
913    pub fn instantiate_class_method_for_head(
914        &mut self,
915        class: &Symbol,
916        method: &Symbol,
917        head: &Type,
918    ) -> Result<Type, TypeError> {
919        let info = self
920            .class_info
921            .get(class)
922            .ok_or_else(|| TypeError::UnknownClass(class.clone()))?;
923        let scheme = info
924            .methods
925            .get(method)
926            .ok_or_else(|| TypeError::UnknownInstanceMethod {
927                class: class.clone(),
928                method: method.clone(),
929            })?;
930
931        let (preds, typ) = instantiate(scheme, &mut self.supply);
932        let class_pred =
933            preds
934                .iter()
935                .find(|p| &p.class == class)
936                .ok_or(TypeError::UnsupportedExpr(
937                    "class method scheme missing class predicate",
938                ))?;
939        let s = unify(&class_pred.typ, head)?;
940        Ok(typ.apply(&s))
941    }
942
943    pub fn typecheck_instance_method(
944        &mut self,
945        prepared: &PreparedInstanceDecl,
946        method: &InstanceMethodImpl,
947    ) -> Result<TypedExpr, TypeError> {
948        let expected =
949            self.instantiate_class_method_for_head(&prepared.class, &method.name, &prepared.head)?;
950        let (typed, preds, actual) = infer_typed(self, method.body.as_ref())?;
951        let s = unify(&actual, &expected)?;
952        let typed = typed.apply(&s);
953        let preds = preds.apply(&s);
954
955        // The only legal “given” constraints inside an instance method are the
956        // instance context (plus superclass closure, plus the instance head
957        // itself). We do *not* allow instance
958        // search for non-ground constraints here, because that would be unsound:
959        // a type variable would unify with any concrete instance head.
960        let mut given = prepared.context.clone();
961
962        // Allow recursive instance methods (e.g. `Eq (List a)` calling `(==)`
963        // on the tail). This is dictionary recursion, not instance search.
964        given.push(Predicate::new(
965            prepared.class.clone(),
966            prepared.head.clone(),
967        ));
968        let mut i = 0;
969        while i < given.len() {
970            let p = given[i].clone();
971            for sup in self.classes.supers_of(&p.class) {
972                given.push(Predicate::new(sup, p.typ.clone()));
973            }
974            i += 1;
975        }
976
977        for pred in &preds {
978            if pred.typ.ftv().is_empty() {
979                if !entails(&self.classes, &given, pred)? {
980                    return Err(TypeError::NoInstance(
981                        pred.class.clone(),
982                        pred.typ.to_string(),
983                    ));
984                }
985            } else if !given
986                .iter()
987                .any(|p| p.class == pred.class && p.typ == pred.typ)
988            {
989                return Err(TypeError::MissingInstanceConstraint {
990                    method: method.name.clone(),
991                    class: pred.class.clone(),
992                    typ: pred.typ.to_string(),
993                });
994            }
995        }
996
997        Ok(typed)
998    }
999
1000    /// Register constructor schemes for an ADT in the type environment.
1001    /// This makes constructors (e.g. `Some`, `None`, `Ok`, `Err`) available
1002    /// to the type checker as normal values.
1003    pub fn register_adt(&mut self, adt: &AdtDecl) {
1004        self.adts.insert(adt.name.clone(), adt.clone());
1005        for (name, scheme) in adt.constructor_schemes() {
1006            self.register_value_scheme(&name, scheme);
1007        }
1008    }
1009
1010    pub fn adt_from_decl(&mut self, decl: &TypeDecl) -> Result<AdtDecl, TypeError> {
1011        let mut adt = AdtDecl::new(&decl.name, &decl.params, &mut self.supply);
1012        let mut param_map: BTreeMap<Symbol, TypeVar> = BTreeMap::new();
1013        for param in &adt.params {
1014            param_map.insert(param.name.clone(), param.var.clone());
1015        }
1016
1017        for variant in &decl.variants {
1018            let mut args = Vec::new();
1019            for arg in &variant.args {
1020                let ty = self.type_from_expr(decl, &param_map, arg)?;
1021                args.push(ty);
1022            }
1023            adt.add_variant(variant.name.clone(), args);
1024        }
1025        Ok(adt)
1026    }
1027
1028    pub fn register_type_decl(&mut self, decl: &TypeDecl) -> Result<(), TypeError> {
1029        if BuiltinTypeId::from_symbol(&decl.name).is_some() {
1030            return Err(TypeError::ReservedTypeName(decl.name.clone()));
1031        }
1032        let adt = self.adt_from_decl(decl)?;
1033        self.register_adt(&adt);
1034        Ok(())
1035    }
1036
1037    fn type_from_expr(
1038        &mut self,
1039        decl: &TypeDecl,
1040        params: &BTreeMap<Symbol, TypeVar>,
1041        expr: &TypeExpr,
1042    ) -> Result<Type, TypeError> {
1043        let span = *expr.span();
1044        let res = (|| match expr {
1045            TypeExpr::Name(_, name) => {
1046                let name_sym = name.to_dotted_symbol();
1047                if let Some(tv) = params.get(&name_sym) {
1048                    Ok(Type::var(tv.clone()))
1049                } else {
1050                    let name = normalize_type_name(&name_sym);
1051                    if let Some(arity) = self.type_arity(decl, &name) {
1052                        Ok(Type::con(name, arity))
1053                    } else {
1054                        Err(TypeError::UnknownTypeName(name))
1055                    }
1056                }
1057            }
1058            TypeExpr::App(_, fun, arg) => {
1059                let fty = self.type_from_expr(decl, params, fun)?;
1060                let aty = self.type_from_expr(decl, params, arg)?;
1061                Ok(type_app_with_result_syntax(fty, aty))
1062            }
1063            TypeExpr::Fun(_, arg, ret) => {
1064                let arg_ty = self.type_from_expr(decl, params, arg)?;
1065                let ret_ty = self.type_from_expr(decl, params, ret)?;
1066                Ok(Type::fun(arg_ty, ret_ty))
1067            }
1068            TypeExpr::Tuple(_, elems) => {
1069                let mut out = Vec::new();
1070                for elem in elems {
1071                    out.push(self.type_from_expr(decl, params, elem)?);
1072                }
1073                Ok(Type::tuple(out))
1074            }
1075            TypeExpr::Record(_, fields) => {
1076                let mut out = Vec::new();
1077                for (name, ty) in fields {
1078                    out.push((name.clone(), self.type_from_expr(decl, params, ty)?));
1079                }
1080                Ok(Type::record(out))
1081            }
1082        })();
1083        res.map_err(|err| err.with_span(&span))
1084    }
1085
1086    fn type_arity(&self, decl: &TypeDecl, name: &Symbol) -> Option<usize> {
1087        if &decl.name == name {
1088            return Some(decl.params.len());
1089        }
1090        if let Some(adt) = self.adts.get(name) {
1091            return Some(adt.params.len());
1092        }
1093        BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
1094    }
1095
1096    fn register_value_scheme(&mut self, name: &Symbol, scheme: Scheme) {
1097        match self.env.lookup(name) {
1098            None => self.env.extend(name.clone(), scheme),
1099            Some(existing) => {
1100                if existing.iter().any(|s| unify(&s.typ, &scheme.typ).is_ok()) {
1101                    return;
1102                }
1103                self.env.extend_overload(name.clone(), scheme);
1104            }
1105        }
1106    }
1107
1108    fn expected_class_param_arities(&self, class: &Symbol) -> Option<Vec<usize>> {
1109        let info = self.class_info.get(class)?;
1110        let mut out = vec![0usize; info.params.len()];
1111        for scheme in info.methods.values() {
1112            for (idx, param) in info.params.iter().enumerate() {
1113                let Some(tv) = scheme.vars.iter().find(|v| v.name.as_ref() == Some(param)) else {
1114                    continue;
1115                };
1116                out[idx] = out[idx].max(max_head_app_arity_for_var(&scheme.typ, tv.id));
1117            }
1118        }
1119        Some(out)
1120    }
1121
1122    fn check_predicate_kind(&self, pred: &Predicate) -> Result<(), TypeError> {
1123        let Some(expected) = self.expected_class_param_arities(&pred.class) else {
1124            // Host-injected classes (via Rust API) won't have `class_info`.
1125            return Ok(());
1126        };
1127
1128        let args: Vec<Type> = if expected.len() == 1 {
1129            vec![pred.typ.clone()]
1130        } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
1131            if parts.len() != expected.len() {
1132                return Ok(());
1133            }
1134            parts.clone()
1135        } else {
1136            return Ok(());
1137        };
1138
1139        for (arg, expected_arity) in args.iter().zip(expected.iter().copied()) {
1140            let Some(got) = type_term_remaining_arity(arg) else {
1141                // If we can't determine the arity (e.g. a bare type var), skip:
1142                // call sites may fix it up, and Rex does not currently do full
1143                // kind inference.
1144                continue;
1145            };
1146            if got != expected_arity {
1147                return Err(TypeError::KindMismatch {
1148                    class: pred.class.clone(),
1149                    expected: expected_arity,
1150                    got,
1151                    typ: arg.to_string(),
1152                });
1153            }
1154        }
1155        Ok(())
1156    }
1157
1158    pub(crate) fn check_predicate_kinds(&self, preds: &[Predicate]) -> Result<(), TypeError> {
1159        for pred in preds {
1160            self.check_predicate_kind(pred)?;
1161        }
1162        Ok(())
1163    }
1164}
1165
1166pub(crate) fn type_from_annotation_expr(
1167    adts: &BTreeMap<Symbol, AdtDecl>,
1168    expr: &TypeExpr,
1169) -> Result<Type, TypeError> {
1170    let span = *expr.span();
1171    let res = (|| match expr {
1172        TypeExpr::Name(_, name) => {
1173            let name = normalize_type_name(&name.to_dotted_symbol());
1174            match annotation_type_arity(adts, &name) {
1175                Some(arity) => Ok(Type::con(name, arity)),
1176                None => Err(TypeError::UnknownTypeName(name)),
1177            }
1178        }
1179        TypeExpr::App(_, fun, arg) => {
1180            let fty = type_from_annotation_expr(adts, fun)?;
1181            let aty = type_from_annotation_expr(adts, arg)?;
1182            Ok(type_app_with_result_syntax(fty, aty))
1183        }
1184        TypeExpr::Fun(_, arg, ret) => {
1185            let arg_ty = type_from_annotation_expr(adts, arg)?;
1186            let ret_ty = type_from_annotation_expr(adts, ret)?;
1187            Ok(Type::fun(arg_ty, ret_ty))
1188        }
1189        TypeExpr::Tuple(_, elems) => {
1190            let mut out = Vec::new();
1191            for elem in elems {
1192                out.push(type_from_annotation_expr(adts, elem)?);
1193            }
1194            Ok(Type::tuple(out))
1195        }
1196        TypeExpr::Record(_, fields) => {
1197            let mut out = Vec::new();
1198            for (name, ty) in fields {
1199                out.push((name.clone(), type_from_annotation_expr(adts, ty)?));
1200            }
1201            Ok(Type::record(out))
1202        }
1203    })();
1204    res.map_err(|err| err.with_span(&span))
1205}
1206
1207pub(crate) fn type_from_annotation_expr_vars(
1208    adts: &BTreeMap<Symbol, AdtDecl>,
1209    expr: &TypeExpr,
1210    vars: &mut BTreeMap<Symbol, TypeVar>,
1211    supply: &mut TypeVarSupply,
1212) -> Result<Type, TypeError> {
1213    let span = *expr.span();
1214    let res = (|| match expr {
1215        TypeExpr::Name(_, name) => {
1216            let name = normalize_type_name(&name.to_dotted_symbol());
1217            if let Some(arity) = annotation_type_arity(adts, &name) {
1218                Ok(Type::con(name, arity))
1219            } else if let Some(tv) = vars.get(&name) {
1220                Ok(Type::var(tv.clone()))
1221            } else {
1222                let is_upper = name
1223                    .as_ref()
1224                    .chars()
1225                    .next()
1226                    .map(|c| c.is_uppercase())
1227                    .unwrap_or(false);
1228                if is_upper {
1229                    return Err(TypeError::UnknownTypeName(name));
1230                }
1231                let tv = supply.fresh(Some(name.clone()));
1232                vars.insert(name.clone(), tv.clone());
1233                Ok(Type::var(tv))
1234            }
1235        }
1236        TypeExpr::App(_, fun, arg) => {
1237            let fty = type_from_annotation_expr_vars(adts, fun, vars, supply)?;
1238            let aty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
1239            Ok(type_app_with_result_syntax(fty, aty))
1240        }
1241        TypeExpr::Fun(_, arg, ret) => {
1242            let arg_ty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
1243            let ret_ty = type_from_annotation_expr_vars(adts, ret, vars, supply)?;
1244            Ok(Type::fun(arg_ty, ret_ty))
1245        }
1246        TypeExpr::Tuple(_, elems) => {
1247            let mut out = Vec::new();
1248            for elem in elems {
1249                out.push(type_from_annotation_expr_vars(adts, elem, vars, supply)?);
1250            }
1251            Ok(Type::tuple(out))
1252        }
1253        TypeExpr::Record(_, fields) => {
1254            let mut out = Vec::new();
1255            for (name, ty) in fields {
1256                out.push((
1257                    name.clone(),
1258                    type_from_annotation_expr_vars(adts, ty, vars, supply)?,
1259                ));
1260            }
1261            Ok(Type::record(out))
1262        }
1263    })();
1264    res.map_err(|err| err.with_span(&span))
1265}
1266
1267fn annotation_type_arity(adts: &BTreeMap<Symbol, AdtDecl>, name: &Symbol) -> Option<usize> {
1268    if let Some(adt) = adts.get(name) {
1269        return Some(adt.params.len());
1270    }
1271    BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
1272}
1273
1274fn normalize_type_name(name: &Symbol) -> Symbol {
1275    if name.as_ref() == "str" {
1276        BuiltinTypeId::String.as_symbol()
1277    } else {
1278        name.clone()
1279    }
1280}
1281
1282fn type_app_with_result_syntax(fun: Type, arg: Type) -> Type {
1283    if let TypeKind::App(head, ok) = fun.as_ref()
1284        && matches!(
1285            head.as_ref(),
1286            TypeKind::Con(c)
1287                if c.is_builtin(BuiltinTypeId::Result) && c.arity() == 2
1288        )
1289    {
1290        return Type::app(Type::app(head.clone(), arg), ok.clone());
1291    }
1292    Type::app(fun, arg)
1293}
1294
1295pub(crate) fn predicates_from_constraints(
1296    adts: &BTreeMap<Symbol, AdtDecl>,
1297    constraints: &[TypeConstraint],
1298    vars: &mut BTreeMap<Symbol, TypeVar>,
1299    supply: &mut TypeVarSupply,
1300) -> Result<Vec<Predicate>, TypeError> {
1301    let mut out = Vec::with_capacity(constraints.len());
1302    for constraint in constraints {
1303        let ty = type_from_annotation_expr_vars(adts, &constraint.typ, vars, supply)?;
1304        out.push(Predicate::new(constraint.class.as_ref(), ty));
1305    }
1306    Ok(out)
1307}