1use std::collections::{BTreeMap, HashMap, HashSet};
4use std::fmt::{self, Display, Formatter};
5use std::sync::Arc;
6
7use chrono::{DateTime, Utc};
8use rexlang_ast::expr::{
9 ClassDecl, ClassMethodSig, Decl, DeclareFnDecl, Expr, FnDecl, InstanceDecl, InstanceMethodImpl,
10 Pattern, Scope, Symbol, TypeConstraint, TypeDecl, TypeExpr, intern, sym,
11};
12use rexlang_lexer::span::Span;
13use rexlang_util::{GasMeter, OutOfGas};
14use rpds::HashTrieMapSync;
15use uuid::Uuid;
16
17use crate::prelude;
18
19#[path = "inference.rs"]
20pub mod inference;
21
22pub use inference::{infer, infer_typed, infer_typed_with_gas, infer_with_gas};
23
/// Numeric identity of a type variable.
///
/// Substitutions (`Subst`) and the unifier's binding table are keyed by this id.
pub type TypeVarId = usize;
25
/// Type constructors built into the language.
///
/// Each variant has a fixed source-level name (see `as_str`) and arity (see
/// `arity`). Scalar types take no type arguments. `List`, `Array`, `Dict` and
/// `Option` take one, and `Result` takes two.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BuiltinTypeId {
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F32,
    F64,
    Bool,
    String,
    Uuid,
    DateTime,
    List,
    Array,
    Dict,
    Option,
    Result,
}
48
49impl BuiltinTypeId {
50 pub fn as_symbol(self) -> Symbol {
51 sym(self.as_str())
52 }
53
54 pub fn as_str(self) -> &'static str {
55 match self {
56 Self::U8 => "u8",
57 Self::U16 => "u16",
58 Self::U32 => "u32",
59 Self::U64 => "u64",
60 Self::I8 => "i8",
61 Self::I16 => "i16",
62 Self::I32 => "i32",
63 Self::I64 => "i64",
64 Self::F32 => "f32",
65 Self::F64 => "f64",
66 Self::Bool => "bool",
67 Self::String => "string",
68 Self::Uuid => "uuid",
69 Self::DateTime => "datetime",
70 Self::List => "List",
71 Self::Array => "Array",
72 Self::Dict => "Dict",
73 Self::Option => "Option",
74 Self::Result => "Result",
75 }
76 }
77
78 pub fn arity(self) -> usize {
79 match self {
80 Self::List | Self::Array | Self::Dict | Self::Option => 1,
81 Self::Result => 2,
82 _ => 0,
83 }
84 }
85
86 pub fn from_symbol(name: &Symbol) -> Option<Self> {
87 Self::from_name(name.as_ref())
88 }
89
90 pub fn from_name(name: &str) -> Option<Self> {
91 match name {
92 "u8" => Some(Self::U8),
93 "u16" => Some(Self::U16),
94 "u32" => Some(Self::U32),
95 "u64" => Some(Self::U64),
96 "i8" => Some(Self::I8),
97 "i16" => Some(Self::I16),
98 "i32" => Some(Self::I32),
99 "i64" => Some(Self::I64),
100 "f32" => Some(Self::F32),
101 "f64" => Some(Self::F64),
102 "bool" => Some(Self::Bool),
103 "string" => Some(Self::String),
104 "uuid" => Some(Self::Uuid),
105 "datetime" => Some(Self::DateTime),
106 "List" => Some(Self::List),
107 "Array" => Some(Self::Array),
108 "Dict" => Some(Self::Dict),
109 "Option" => Some(Self::Option),
110 "Result" => Some(Self::Result),
111 _ => None,
112 }
113 }
114}
115
116#[derive(Clone, Debug, Eq, Hash, PartialEq)]
117pub struct TypeVar {
118 pub id: TypeVarId,
119 pub name: Option<Symbol>,
120}
121
122impl TypeVar {
123 pub fn new(id: TypeVarId, name: impl Into<Option<Symbol>>) -> Self {
124 Self {
125 id,
126 name: name.into(),
127 }
128 }
129}
130
/// A named type constructor together with the number of type arguments it takes.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TypeConst {
    /// Interned constructor name.
    pub name: Symbol,
    /// Number of type arguments the constructor expects.
    pub arity: usize,
    /// `Some` for language builtins (see `Type::builtin`), `None` for
    /// user-defined constructors (see `Type::user_con`).
    pub builtin_id: Option<BuiltinTypeId>,
}
137
/// A type term.
///
/// Cloning is cheap because the `TypeKind` node is shared behind an `Arc`.
/// Equality and hashing are structural.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Type(Arc<TypeKind>);
140
/// The structural forms a `Type` can take.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TypeKind {
    /// A type variable.
    Var(TypeVar),
    /// A type constructor such as `i32` or `List`.
    Con(TypeConst),
    /// Application of a type to a single argument. Multi-argument constructors
    /// are curried, e.g. `Type::result` builds `App(App(Result, err), ok)`.
    App(Type, Type),
    /// Function type `arg -> ret`.
    Fun(Type, Type),
    /// Tuple type with positional elements.
    Tuple(Vec<Type>),
    /// Record type. `Type::record` keeps the fields sorted by name.
    Record(Vec<(Symbol, Type)>),
}
154
impl Type {
    /// Wraps a `TypeKind` node in a shared `Type`.
    pub fn new(kind: TypeKind) -> Self {
        Type(Arc::new(kind))
    }

    /// Builds a type constructor by name.
    ///
    /// A builtin name resolves to the builtin constructor only when `arity`
    /// matches the builtin's arity. Otherwise a user-defined constructor is
    /// created.
    pub fn con(name: impl AsRef<str>, arity: usize) -> Self {
        if let Some(id) = BuiltinTypeId::from_name(name.as_ref())
            && id.arity() == arity
        {
            return Self::builtin(id);
        }
        Self::user_con(name, arity)
    }

    /// Builds a user-defined (non-builtin) constructor, interning `name`.
    pub fn user_con(name: impl AsRef<str>, arity: usize) -> Self {
        Type::new(TypeKind::Con(TypeConst {
            name: intern(name.as_ref()),
            arity,
            builtin_id: None,
        }))
    }

    /// Builds the constructor for builtin `id` with its canonical name and arity.
    pub fn builtin(id: BuiltinTypeId) -> Self {
        Type::new(TypeKind::Con(TypeConst {
            name: id.as_symbol(),
            arity: id.arity(),
            builtin_id: Some(id),
        }))
    }

    /// Type variable `tv`.
    pub fn var(tv: TypeVar) -> Self {
        Type::new(TypeKind::Var(tv))
    }

    /// Function type `a -> b`.
    pub fn fun(a: Type, b: Type) -> Self {
        Type::new(TypeKind::Fun(a, b))
    }

    /// Type application `f arg`.
    pub fn app(f: Type, arg: Type) -> Self {
        Type::new(TypeKind::App(f, arg))
    }

    /// Tuple type with the given element types.
    pub fn tuple(elems: Vec<Type>) -> Self {
        Type::new(TypeKind::Tuple(elems))
    }

    /// Record type. Fields are sorted by name because both unifiers compare
    /// record fields positionally.
    pub fn record(mut fields: Vec<(Symbol, Type)>) -> Self {
        fields.sort_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
        Type::new(TypeKind::Record(fields))
    }

    /// `List elem`.
    pub fn list(elem: Type) -> Type {
        Type::app(Type::builtin(BuiltinTypeId::List), elem)
    }

    /// `Array elem`.
    pub fn array(elem: Type) -> Type {
        Type::app(Type::builtin(BuiltinTypeId::Array), elem)
    }

    /// `Dict elem`.
    pub fn dict(elem: Type) -> Type {
        Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
    }

    /// `Option elem`.
    pub fn option(elem: Type) -> Type {
        Type::app(Type::builtin(BuiltinTypeId::Option), elem)
    }

    /// `Result` with success type `ok` and error type `err`.
    ///
    /// The error argument is applied first (`Result err ok`); `Display`
    /// prints it back as `(Result ok err)`.
    pub fn result(ok: Type, err: Type) -> Type {
        Type::app(Type::app(Type::builtin(BuiltinTypeId::Result), err), ok)
    }

    /// Applies substitution `s` and reports whether anything changed.
    ///
    /// Unchanged subterms are returned as clones of the original (shared)
    /// node, so untouched structure is not reallocated. A variable is
    /// replaced by its binding in `s` as-is; the binding is not
    /// substituted again.
    fn apply_with_change(&self, s: &Subst) -> (Type, bool) {
        match self.as_ref() {
            TypeKind::Var(tv) => match s.get(&tv.id) {
                Some(ty) => (ty.clone(), true),
                None => (self.clone(), false),
            },
            TypeKind::Con(_) => (self.clone(), false),
            TypeKind::App(l, r) => {
                let (l_new, l_changed) = l.apply_with_change(s);
                let (r_new, r_changed) = r.apply_with_change(s);
                if l_changed || r_changed {
                    (Type::app(l_new, r_new), true)
                } else {
                    (self.clone(), false)
                }
            }
            TypeKind::Fun(_, _) => {
                // Walk the right-nested spine `a1 -> a2 -> ... -> ret`
                // iteratively so long curried function types do not recurse
                // once per arrow.
                let mut args = Vec::new();
                let mut changed = false;
                let mut cur: &Type = self;
                while let TypeKind::Fun(a, b) = cur.as_ref() {
                    let (a_new, a_changed) = a.apply_with_change(s);
                    changed |= a_changed;
                    args.push(a_new);
                    cur = b;
                }
                let (ret_new, ret_changed) = cur.apply_with_change(s);
                changed |= ret_changed;
                if !changed {
                    return (self.clone(), false);
                }
                // Rebuild the spine from the return type outwards.
                let mut out = ret_new;
                for a_new in args.into_iter().rev() {
                    out = Type::fun(a_new, out);
                }
                (out, true)
            }
            TypeKind::Tuple(ts) => {
                let mut changed = false;
                let mut out = Vec::with_capacity(ts.len());
                for t in ts {
                    let (t_new, t_changed) = t.apply_with_change(s);
                    changed |= t_changed;
                    out.push(t_new);
                }
                if changed {
                    (Type::new(TypeKind::Tuple(out)), true)
                } else {
                    (self.clone(), false)
                }
            }
            TypeKind::Record(fields) => {
                let mut changed = false;
                let mut out = Vec::with_capacity(fields.len());
                for (k, v) in fields {
                    let (v_new, v_changed) = v.apply_with_change(s);
                    changed |= v_changed;
                    out.push((k.clone(), v_new));
                }
                if changed {
                    (Type::new(TypeKind::Record(out)), true)
                } else {
                    (self.clone(), false)
                }
            }
        }
    }
}
298
299impl AsRef<TypeKind> for Type {
300 fn as_ref(&self) -> &TypeKind {
301 self.0.as_ref()
302 }
303}
304
305impl std::ops::Deref for Type {
306 type Target = TypeKind;
307
308 fn deref(&self) -> &Self::Target {
309 &self.0
310 }
311}
312
313impl Display for Type {
314 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
315 match self.as_ref() {
316 TypeKind::Var(tv) => match &tv.name {
317 Some(name) => write!(f, "'{}", name),
318 None => write!(f, "t{}", tv.id),
319 },
320 TypeKind::Con(c) => write!(f, "{}", c.name),
321 TypeKind::App(l, r) => {
322 if let TypeKind::App(head, err) = l.as_ref()
328 && matches!(
329 head.as_ref(),
330 TypeKind::Con(c)
331 if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
332 )
333 {
334 return write!(f, "(Result {} {})", r, err);
335 }
336 write!(f, "({} {})", l, r)
337 }
338 TypeKind::Fun(a, b) => write!(f, "({} -> {})", a, b),
339 TypeKind::Tuple(elems) => {
340 write!(f, "(")?;
341 for (i, t) in elems.iter().enumerate() {
342 write!(f, "{}", t)?;
343 if i + 1 < elems.len() {
344 write!(f, ", ")?;
345 }
346 }
347 write!(f, ")")
348 }
349 TypeKind::Record(fields) => {
350 write!(f, "{{")?;
351 for (i, (name, ty)) in fields.iter().enumerate() {
352 write!(f, "{}: {}", name, ty)?;
353 if i + 1 < fields.len() {
354 write!(f, ", ")?;
355 }
356 }
357 write!(f, "}}")
358 }
359 }
360 }
361}
362
363#[derive(Clone, Debug, PartialEq, Eq, Hash)]
364pub struct Predicate {
365 pub class: Symbol,
366 pub typ: Type,
367}
368
369impl Predicate {
370 pub fn new(class: impl AsRef<str>, typ: Type) -> Self {
371 Self {
372 class: intern(class.as_ref()),
373 typ,
374 }
375 }
376}
377
378#[derive(Clone, Debug, PartialEq)]
379pub struct Scheme {
380 pub vars: Vec<TypeVar>,
381 pub preds: Vec<Predicate>,
382 pub typ: Type,
383}
384
385impl Scheme {
386 pub fn new(vars: Vec<TypeVar>, preds: Vec<Predicate>, typ: Type) -> Self {
387 Self { vars, preds, typ }
388 }
389}
390
/// A substitution from type-variable ids to types.
///
/// Stored in a persistent map: `insert`/`remove` return a new map and leave
/// the original untouched.
pub type Subst = HashTrieMapSync<TypeVarId, Type>;
392
/// Operations shared by everything that can mention type variables.
pub trait Types: Sized {
    /// Returns a copy of `self` with substitution `s` applied.
    fn apply(&self, s: &Subst) -> Self;
    /// Returns the ids of the free type variables occurring in `self`.
    fn ftv(&self) -> HashSet<TypeVarId>;
}
397
398impl Types for Type {
399 fn apply(&self, s: &Subst) -> Self {
400 self.apply_with_change(s).0
401 }
402
403 fn ftv(&self) -> HashSet<TypeVarId> {
404 let mut out = HashSet::new();
405 let mut stack: Vec<&Type> = vec![self];
406 while let Some(t) = stack.pop() {
407 match t.as_ref() {
408 TypeKind::Var(tv) => {
409 out.insert(tv.id);
410 }
411 TypeKind::Con(_) => {}
412 TypeKind::App(l, r) => {
413 stack.push(l);
414 stack.push(r);
415 }
416 TypeKind::Fun(a, b) => {
417 stack.push(a);
418 stack.push(b);
419 }
420 TypeKind::Tuple(ts) => {
421 for t in ts {
422 stack.push(t);
423 }
424 }
425 TypeKind::Record(fields) => {
426 for (_, ty) in fields {
427 stack.push(ty);
428 }
429 }
430 }
431 }
432 out
433 }
434}
435
436impl Types for Predicate {
437 fn apply(&self, s: &Subst) -> Self {
438 Predicate {
439 class: self.class.clone(),
440 typ: self.typ.apply(s),
441 }
442 }
443
444 fn ftv(&self) -> HashSet<TypeVarId> {
445 self.typ.ftv()
446 }
447}
448
449impl Types for Scheme {
450 fn apply(&self, s: &Subst) -> Self {
451 let mut s_pruned = Subst::new_sync();
452 for (k, v) in s.iter() {
453 if !self.vars.iter().any(|var| var.id == *k) {
454 s_pruned = s_pruned.insert(*k, v.clone());
455 }
456 }
457 Scheme::new(
458 self.vars.clone(),
459 self.preds.iter().map(|p| p.apply(&s_pruned)).collect(),
460 self.typ.apply(&s_pruned),
461 )
462 }
463
464 fn ftv(&self) -> HashSet<TypeVarId> {
465 let mut ftv = self.typ.ftv();
466 for p in &self.preds {
467 ftv.extend(p.ftv());
468 }
469 for v in &self.vars {
470 ftv.remove(&v.id);
471 }
472 ftv
473 }
474}
475
476impl<T: Types> Types for Vec<T> {
477 fn apply(&self, s: &Subst) -> Self {
478 self.iter().map(|t| t.apply(s)).collect()
479 }
480
481 fn ftv(&self) -> HashSet<TypeVarId> {
482 self.iter().flat_map(Types::ftv).collect()
483 }
484}
485
/// An expression node annotated with its type.
#[derive(Clone, Debug, PartialEq)]
pub struct TypedExpr {
    /// Type of this node.
    pub typ: Type,
    /// The node's form and children.
    pub kind: TypedExprKind,
}
491
impl TypedExpr {
    /// Builds a typed node from its type and form.
    pub fn new(typ: Type, kind: TypedExprKind) -> Self {
        Self { typ, kind }
    }

    /// Returns a copy of this tree with substitution `s` applied to every
    /// node's type and to the `overloads` types of `Var` nodes.
    ///
    /// Chains of nested `Lam` bodies and of `App` function positions are
    /// handled with loops rather than one recursive call per link, so long
    /// curried lambdas and applications do not deepen the call stack.
    pub fn apply(&self, s: &Subst) -> Self {
        match &self.kind {
            TypedExprKind::Lam { .. } => {
                // Collect (param, substituted type) for each lambda in the
                // chain, outermost first.
                let mut params: Vec<(Symbol, Type)> = Vec::new();
                let mut cur = self;
                while let TypedExprKind::Lam { param, body } = &cur.kind {
                    params.push((param.clone(), cur.typ.apply(s)));
                    cur = body.as_ref();
                }
                // Rebuild from the innermost non-lambda body outwards.
                let mut out = cur.apply(s);
                for (param, typ) in params.into_iter().rev() {
                    out = TypedExpr {
                        typ,
                        kind: TypedExprKind::Lam {
                            param,
                            body: Box::new(out),
                        },
                    };
                }
                return out;
            }
            TypedExprKind::App(..) => {
                // Walk down the function position. `apps` holds
                // (substituted node type, argument), outermost application first.
                let mut apps: Vec<(Type, &TypedExpr)> = Vec::new();
                let mut cur = self;
                while let TypedExprKind::App(f, x) = &cur.kind {
                    apps.push((cur.typ.apply(s), x.as_ref()));
                    cur = f.as_ref();
                }
                // Rebuild starting from the head, innermost application first.
                let mut out = cur.apply(s);
                for (typ, arg) in apps.into_iter().rev() {
                    out = TypedExpr {
                        typ,
                        kind: TypedExprKind::App(Box::new(out), Box::new(arg.apply(s))),
                    };
                }
                return out;
            }
            _ => {}
        }

        let typ = self.typ.apply(s);
        let kind = match &self.kind {
            TypedExprKind::Bool(v) => TypedExprKind::Bool(*v),
            TypedExprKind::Uint(v) => TypedExprKind::Uint(*v),
            TypedExprKind::Int(v) => TypedExprKind::Int(*v),
            TypedExprKind::Float(v) => TypedExprKind::Float(*v),
            TypedExprKind::String(v) => TypedExprKind::String(v.clone()),
            TypedExprKind::Uuid(v) => TypedExprKind::Uuid(*v),
            TypedExprKind::DateTime(v) => TypedExprKind::DateTime(*v),
            TypedExprKind::Hole => TypedExprKind::Hole,
            TypedExprKind::Tuple(elems) => {
                TypedExprKind::Tuple(elems.iter().map(|e| e.apply(s)).collect())
            }
            TypedExprKind::List(elems) => {
                TypedExprKind::List(elems.iter().map(|e| e.apply(s)).collect())
            }
            TypedExprKind::Dict(kvs) => {
                let mut out = BTreeMap::new();
                for (k, v) in kvs {
                    out.insert(k.clone(), v.apply(s));
                }
                TypedExprKind::Dict(out)
            }
            TypedExprKind::RecordUpdate { base, updates } => {
                let mut out = BTreeMap::new();
                for (k, v) in updates {
                    out.insert(k.clone(), v.apply(s));
                }
                TypedExprKind::RecordUpdate {
                    base: Box::new(base.apply(s)),
                    updates: out,
                }
            }
            TypedExprKind::Var { name, overloads } => TypedExprKind::Var {
                name: name.clone(),
                overloads: overloads.iter().map(|t| t.apply(s)).collect(),
            },
            // `App` and `Lam` are handled by the early returns above; these
            // arms only keep the match exhaustive.
            TypedExprKind::App(f, x) => {
                TypedExprKind::App(Box::new(f.apply(s)), Box::new(x.apply(s)))
            }
            TypedExprKind::Project { expr, field } => TypedExprKind::Project {
                expr: Box::new(expr.apply(s)),
                field: field.clone(),
            },
            TypedExprKind::Lam { param, body } => TypedExprKind::Lam {
                param: param.clone(),
                body: Box::new(body.apply(s)),
            },
            TypedExprKind::Let { name, def, body } => TypedExprKind::Let {
                name: name.clone(),
                def: Box::new(def.apply(s)),
                body: Box::new(body.apply(s)),
            },
            TypedExprKind::LetRec { bindings, body } => TypedExprKind::LetRec {
                bindings: bindings
                    .iter()
                    .map(|(name, def)| (name.clone(), def.apply(s)))
                    .collect(),
                body: Box::new(body.apply(s)),
            },
            TypedExprKind::Ite {
                cond,
                then_expr,
                else_expr,
            } => TypedExprKind::Ite {
                cond: Box::new(cond.apply(s)),
                then_expr: Box::new(then_expr.apply(s)),
                else_expr: Box::new(else_expr.apply(s)),
            },
            TypedExprKind::Match { scrutinee, arms } => TypedExprKind::Match {
                scrutinee: Box::new(scrutinee.apply(s)),
                arms: arms.iter().map(|(p, e)| (p.clone(), e.apply(s))).collect(),
            },
        };
        TypedExpr { typ, kind }
    }
}
614
/// The node forms of a `TypedExpr`.
#[derive(Clone, Debug, PartialEq)]
pub enum TypedExprKind {
    /// Boolean literal.
    Bool(bool),
    /// Unsigned integer literal.
    Uint(u64),
    /// Signed integer literal.
    Int(i64),
    /// Floating-point literal.
    Float(f64),
    /// String literal.
    String(String),
    /// UUID literal.
    Uuid(Uuid),
    /// UTC timestamp literal.
    DateTime(DateTime<Utc>),
    /// Placeholder node. NOTE(review): its surface syntax is not visible here.
    Hole,
    /// Tuple expression.
    Tuple(Vec<TypedExpr>),
    /// List expression.
    List(Vec<TypedExpr>),
    /// Keyed literal mapping names to expressions, ordered by key.
    Dict(BTreeMap<Symbol, TypedExpr>),
    /// `base` with the entries in `updates` replaced.
    RecordUpdate {
        base: Box<TypedExpr>,
        updates: BTreeMap<Symbol, TypedExpr>,
    },
    /// Variable reference. `overloads` presumably lists the candidate types
    /// for an overloaded name; confirm against inference.
    Var {
        name: Symbol,
        overloads: Vec<Type>,
    },
    /// Application of a function to one argument.
    App(Box<TypedExpr>, Box<TypedExpr>),
    /// Field projection `expr.field`.
    Project {
        expr: Box<TypedExpr>,
        field: Symbol,
    },
    /// Single-parameter lambda; the parameter's type is not stored separately.
    Lam {
        param: Symbol,
        body: Box<TypedExpr>,
    },
    /// Non-recursive binding of `name` to `def` within `body`.
    Let {
        name: Symbol,
        def: Box<TypedExpr>,
        body: Box<TypedExpr>,
    },
    /// Group of bindings visible in `body`; presumably mutually recursive.
    LetRec {
        bindings: Vec<(Symbol, TypedExpr)>,
        body: Box<TypedExpr>,
    },
    /// If-then-else.
    Ite {
        cond: Box<TypedExpr>,
        then_expr: Box<TypedExpr>,
        else_expr: Box<TypedExpr>,
    },
    /// Pattern match of `scrutinee` against `arms`, in order.
    Match {
        scrutinee: Box<TypedExpr>,
        arms: Vec<(Pattern, TypedExpr)>,
    },
}
664
665pub fn compose_subst(a: Subst, b: Subst) -> Subst {
670 if subst_is_empty(&a) {
671 return b;
672 }
673 if subst_is_empty(&b) {
674 return a;
675 }
676 let mut res = Subst::new_sync();
677 for (k, v) in b.iter() {
678 res = res.insert(*k, v.apply(&a));
679 }
680 for (k, v) in a.iter() {
681 res = res.insert(*k, v.clone());
682 }
683 res
684}
685
686fn subst_is_empty(s: &Subst) -> bool {
687 s.iter().next().is_none()
688}
689
/// Errors produced by type checking and unification.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum TypeError {
    /// Two types could not be unified (both rendered with `Display`).
    #[error("types do not unify: {0} vs {1}")]
    Unification(String, String),
    /// Binding the variable would create an infinite type.
    #[error("occurs check failed for {0} in {1}")]
    Occurs(TypeVarId, String),
    /// Reference to a class that is not defined.
    #[error("unknown class {0}")]
    UnknownClass(Symbol),
    /// No instance of the class exists for the rendered type.
    #[error("no instance for {0} {1}")]
    NoInstance(Symbol, String),
    /// Reference to a type name that is not defined.
    #[error("unknown type {0}")]
    UnknownTypeName(Symbol),
    /// Attempt to define a type using a builtin type's name.
    #[error("cannot redefine reserved builtin type `{0}`")]
    ReservedTypeName(Symbol),
    /// A value name was defined more than once.
    #[error("duplicate value definition `{0}`")]
    DuplicateValue(Symbol),
    /// A class name was defined more than once.
    #[error("duplicate class definition `{0}`")]
    DuplicateClass(Symbol),
    /// A class declared with no type parameters.
    #[error("class `{class}` must have at least one type parameter (got {got})")]
    InvalidClassArity { class: Symbol, got: usize },
    /// A class declares the same method twice.
    #[error("duplicate class method `{0}`")]
    DuplicateClassMethod(Symbol),
    /// An instance implements a method its class does not declare.
    #[error("unknown method `{method}` in instance of class `{class}`")]
    UnknownInstanceMethod { class: Symbol, method: Symbol },
    /// An instance omits a method its class declares.
    #[error("missing implementation of `{method}` for instance of class `{class}`")]
    MissingInstanceMethod { class: Symbol, method: Symbol },
    /// An instance method needs a constraint the instance context does not provide.
    #[error(
        "instance method `{method}` requires constraint {class} {typ}, but it is not in the instance context"
    )]
    MissingInstanceConstraint {
        method: Symbol,
        class: Symbol,
        typ: String,
    },
    /// Reference to an unbound variable.
    #[error("unbound variable {0}")]
    UnknownVar(Symbol),
    /// An overloaded name could not be resolved to a single candidate.
    #[error("ambiguous overload for {0}")]
    AmbiguousOverload(Symbol),
    /// Constrained type variables that do not appear in the type itself.
    #[error("ambiguous type variable(s) {vars:?} in constraints: {constraints}")]
    AmbiguousTypeVars {
        vars: Vec<TypeVarId>,
        constraints: String,
    },
    /// A class was applied to a type with the wrong number of remaining arguments.
    #[error(
        "kind mismatch for class `{class}`: expected {expected} type argument(s) remaining, got {got} for {typ}"
    )]
    KindMismatch {
        class: Symbol,
        expected: usize,
        got: usize,
        typ: String,
    },
    /// Inferred non-ground constraints missing from a declared signature.
    #[error("missing type class constraint(s): {constraints}")]
    MissingConstraints { constraints: String },
    /// An expression form not supported by the checker (or an empty record).
    #[error("unsupported expression {0}")]
    UnsupportedExpr(&'static str),
    /// Projection of a field the type does not have.
    #[error("unknown field `{field}` on {typ}")]
    UnknownField { field: Symbol, typ: String },
    /// Projection of a field whose presence cannot be established.
    #[error("field `{field}` is not definitely available on {typ}")]
    FieldNotKnown { field: Symbol, typ: String },
    /// A match that does not cover every case.
    #[error("non-exhaustive match for {typ}: missing {missing:?}")]
    NonExhaustiveMatch { typ: String, missing: Vec<Symbol> },
    /// Wraps another error with the source span where it occurred.
    #[error("at {span}: {error}")]
    Spanned { span: Span, error: Box<TypeError> },
    /// Checker-internal failure, including the inference depth limit.
    #[error("internal error: {0}")]
    Internal(String),
    /// The gas meter ran out during checking.
    #[error("{0}")]
    OutOfGas(#[from] OutOfGas),
}
759
760fn with_span(span: &Span, err: TypeError) -> TypeError {
761 match err {
762 TypeError::Spanned { .. } => err,
763 other => TypeError::Spanned {
764 span: *span,
765 error: Box::new(other),
766 },
767 }
768}
769
770fn format_constraints_referencing_vars(preds: &[Predicate], vars: &[TypeVarId]) -> String {
771 if vars.is_empty() {
772 return String::new();
773 }
774 let var_set: HashSet<TypeVarId> = vars.iter().copied().collect();
775 let mut parts = Vec::new();
776 for pred in preds {
777 let ftv = pred.ftv();
778 if ftv.iter().any(|v| var_set.contains(v)) {
779 parts.push(format!("{} {}", pred.class, pred.typ));
780 }
781 }
782 if parts.is_empty() {
783 for pred in preds {
785 parts.push(format!("{} {}", pred.class, pred.typ));
786 }
787 }
788 parts.join(", ")
789}
790
791fn reject_ambiguous_scheme(scheme: &Scheme) -> Result<(), TypeError> {
792 let quantified: HashSet<TypeVarId> = scheme.vars.iter().map(|v| v.id).collect();
796 if quantified.is_empty() {
797 return Ok(());
798 }
799
800 let typ_ftv = scheme.typ.ftv();
801 let mut vars = HashSet::new();
802 for pred in &scheme.preds {
803 let TypeKind::Var(tv) = pred.typ.as_ref() else {
804 continue;
805 };
806 if quantified.contains(&tv.id) && !typ_ftv.contains(&tv.id) {
807 vars.insert(tv.id);
808 }
809 }
810
811 if vars.is_empty() {
812 return Ok(());
813 }
814 let mut vars: Vec<TypeVarId> = vars.into_iter().collect();
815 vars.sort_unstable();
816 let constraints = format_constraints_referencing_vars(&scheme.preds, &vars);
817 Err(TypeError::AmbiguousTypeVars { vars, constraints })
818}
819
820fn scheme_compatible(existing: &Scheme, declared: &Scheme) -> bool {
821 let s = match unify(&existing.typ, &declared.typ) {
822 Ok(s) => s,
823 Err(_) => return false,
824 };
825
826 let existing_preds = existing.preds.apply(&s);
827 let declared_preds = declared.preds.apply(&s);
828
829 let mut lhs: Vec<(Symbol, String)> = existing_preds
830 .iter()
831 .map(|p| (p.class.clone(), p.typ.to_string()))
832 .collect();
833 let mut rhs: Vec<(Symbol, String)> = declared_preds
834 .iter()
835 .map(|p| (p.class.clone(), p.typ.to_string()))
836 .collect();
837 lhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
838 rhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
839 lhs == rhs
840}
841
/// In-place unification state used during inference.
#[derive(Debug)]
struct Unifier<'g> {
    /// Binding table indexed by `TypeVarId`; `None` means unbound.
    subs: Vec<Option<Type>>,
    /// Optional gas meter, charged per inference node and per unification step.
    gas: Option<&'g mut GasMeter>,
    /// Nesting limit enforced by `with_infer_depth`; `None` means unlimited.
    max_infer_depth: Option<usize>,
    /// Current inference nesting depth.
    infer_depth: usize,
}
855
/// Resource limits applied while type checking.
#[derive(Clone, Copy, Debug)]
pub struct TypeSystemLimits {
    /// Maximum inference nesting depth; `None` disables the check.
    pub max_infer_depth: Option<usize>,
}

impl TypeSystemLimits {
    /// Depth cap used by `safe_defaults`.
    const SAFE_MAX_INFER_DEPTH: usize = 4096;

    /// No limits at all.
    pub fn unlimited() -> Self {
        TypeSystemLimits {
            max_infer_depth: None,
        }
    }

    /// Conservative limits suitable for untrusted input.
    pub fn safe_defaults() -> Self {
        TypeSystemLimits {
            max_infer_depth: Some(Self::SAFE_MAX_INFER_DEPTH),
        }
    }
}

/// The default limits are the safe ones.
impl Default for TypeSystemLimits {
    fn default() -> Self {
        TypeSystemLimits::safe_defaults()
    }
}
880
881fn superclass_closure(class_env: &ClassEnv, given: &[Predicate]) -> Vec<Predicate> {
882 let mut closure: Vec<Predicate> = given.to_vec();
883 let mut i = 0;
884 while i < closure.len() {
885 let p = closure[i].clone();
886 for sup in class_env.supers_of(&p.class) {
887 closure.push(Predicate::new(sup, p.typ.clone()));
888 }
889 i += 1;
890 }
891 closure
892}
893
894fn check_non_ground_predicates_declared(
895 class_env: &ClassEnv,
896 declared: &[Predicate],
897 inferred: &[Predicate],
898) -> Result<(), TypeError> {
899 let closure = superclass_closure(class_env, declared);
903 let closure_keys: HashSet<String> = closure
904 .iter()
905 .map(|p| format!("{} {}", p.class, p.typ))
906 .collect();
907 let mut missing = Vec::new();
908 for pred in inferred {
909 if pred.typ.ftv().is_empty() {
910 continue;
911 }
912 let key = format!("{} {}", pred.class, pred.typ);
913 if !closure_keys.contains(&key) {
914 missing.push(key);
915 }
916 }
917
918 missing.sort();
919 missing.dedup();
920 if missing.is_empty() {
921 return Ok(());
922 }
923 Err(TypeError::MissingConstraints {
924 constraints: missing.join(", "),
925 })
926}
927
928fn type_term_remaining_arity(ty: &Type) -> Option<usize> {
929 match ty.as_ref() {
930 TypeKind::Var(_) => None,
931 TypeKind::Con(tc) => Some(tc.arity),
932 TypeKind::App(l, _) => {
933 let a = type_term_remaining_arity(l)?;
934 Some(a.saturating_sub(1))
935 }
936 TypeKind::Fun(..) | TypeKind::Tuple(..) | TypeKind::Record(..) => Some(0),
937 }
938}
939
940fn max_head_app_arity_for_var(ty: &Type, var_id: TypeVarId) -> usize {
941 let mut max_arity = 0usize;
942 let mut stack: Vec<&Type> = vec![ty];
943 while let Some(t) = stack.pop() {
944 match t.as_ref() {
945 TypeKind::Var(_) | TypeKind::Con(_) => {}
946 TypeKind::App(l, r) => {
947 let mut head = t;
949 let mut args = 0usize;
950 while let TypeKind::App(left, _) = head.as_ref() {
951 args += 1;
952 head = left;
953 }
954 if let TypeKind::Var(tv) = head.as_ref()
955 && tv.id == var_id
956 {
957 max_arity = max_arity.max(args);
958 }
959 stack.push(l);
960 stack.push(r);
961 }
962 TypeKind::Fun(a, b) => {
963 stack.push(a);
964 stack.push(b);
965 }
966 TypeKind::Tuple(ts) => {
967 for t in ts {
968 stack.push(t);
969 }
970 }
971 TypeKind::Record(fields) => {
972 for (_, t) in fields {
973 stack.push(t);
974 }
975 }
976 }
977 }
978 max_arity
979}
980
981impl<'g> Unifier<'g> {
982 fn new(max_infer_depth: Option<usize>) -> Self {
983 Self {
984 subs: Vec::new(),
985 gas: None,
986 max_infer_depth,
987 infer_depth: 0,
988 }
989 }
990
991 fn with_gas(gas: &'g mut GasMeter, max_infer_depth: Option<usize>) -> Self {
992 Self {
993 subs: Vec::new(),
994 gas: Some(gas),
995 max_infer_depth,
996 infer_depth: 0,
997 }
998 }
999
1000 fn with_infer_depth<T>(
1001 &mut self,
1002 span: Span,
1003 f: impl FnOnce(&mut Self) -> Result<T, TypeError>,
1004 ) -> Result<T, TypeError> {
1005 if let Some(max) = self.max_infer_depth
1006 && self.infer_depth >= max
1007 {
1008 return Err(TypeError::Spanned {
1009 span,
1010 error: Box::new(TypeError::Internal(format!(
1011 "maximum inference depth exceeded (max {max})"
1012 ))),
1013 });
1014 }
1015 self.infer_depth += 1;
1016 let res = f(self);
1017 self.infer_depth = self.infer_depth.saturating_sub(1);
1018 res
1019 }
1020
1021 fn charge_infer_node(&mut self) -> Result<(), TypeError> {
1022 let Some(gas) = self.gas.as_mut() else {
1023 return Ok(());
1024 };
1025 let cost = gas.costs.infer_node;
1026 gas.charge(cost)?;
1027 Ok(())
1028 }
1029
1030 fn charge_unify_step(&mut self) -> Result<(), TypeError> {
1031 let Some(gas) = self.gas.as_mut() else {
1032 return Ok(());
1033 };
1034 let cost = gas.costs.unify_step;
1035 gas.charge(cost)?;
1036 Ok(())
1037 }
1038
1039 fn bind_var(&mut self, id: TypeVarId, ty: Type) {
1040 if id >= self.subs.len() {
1041 self.subs.resize(id + 1, None);
1042 }
1043 self.subs[id] = Some(ty);
1044 }
1045
1046 fn prune(&mut self, ty: &Type) -> Type {
1047 match ty.as_ref() {
1048 TypeKind::Var(tv) => {
1049 let bound = self.subs.get(tv.id).and_then(|t| t.clone());
1050 match bound {
1051 Some(bound) => {
1052 let pruned = self.prune(&bound);
1053 self.bind_var(tv.id, pruned.clone());
1054 pruned
1055 }
1056 None => ty.clone(),
1057 }
1058 }
1059 TypeKind::Con(_) => ty.clone(),
1060 TypeKind::App(l, r) => {
1061 let l = self.prune(l);
1062 let r = self.prune(r);
1063 Type::app(l, r)
1064 }
1065 TypeKind::Fun(a, b) => {
1066 let a = self.prune(a);
1067 let b = self.prune(b);
1068 Type::fun(a, b)
1069 }
1070 TypeKind::Tuple(ts) => {
1071 Type::new(TypeKind::Tuple(ts.iter().map(|t| self.prune(t)).collect()))
1072 }
1073 TypeKind::Record(fields) => Type::new(TypeKind::Record(
1074 fields
1075 .iter()
1076 .map(|(name, ty)| (name.clone(), self.prune(ty)))
1077 .collect(),
1078 )),
1079 }
1080 }
1081
1082 fn apply_type(&mut self, ty: &Type) -> Type {
1083 self.prune(ty)
1084 }
1085
1086 fn occurs(&mut self, id: TypeVarId, ty: &Type) -> bool {
1087 match self.prune(ty).as_ref() {
1088 TypeKind::Var(tv) => tv.id == id,
1089 TypeKind::Con(_) => false,
1090 TypeKind::App(l, r) => self.occurs(id, l) || self.occurs(id, r),
1091 TypeKind::Fun(a, b) => self.occurs(id, a) || self.occurs(id, b),
1092 TypeKind::Tuple(ts) => ts.iter().any(|t| self.occurs(id, t)),
1093 TypeKind::Record(fields) => fields.iter().any(|(_, ty)| self.occurs(id, ty)),
1094 }
1095 }
1096
1097 fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
1098 self.charge_unify_step()?;
1099 let t1 = self.prune(t1);
1100 let t2 = self.prune(t2);
1101 match (t1.as_ref(), t2.as_ref()) {
1102 (TypeKind::Var(a), TypeKind::Var(b)) if a.id == b.id => Ok(()),
1103 (TypeKind::Var(tv), other) | (other, TypeKind::Var(tv)) => {
1104 if self.occurs(tv.id, &Type::new(other.clone())) {
1105 Err(TypeError::Occurs(
1106 tv.id,
1107 Type::new(other.clone()).to_string(),
1108 ))
1109 } else {
1110 self.bind_var(tv.id, Type::new(other.clone()));
1111 Ok(())
1112 }
1113 }
1114 (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(()),
1115 (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1116 self.unify(l1, l2)?;
1117 self.unify(r1, r2)
1118 }
1119 (TypeKind::Fun(a1, b1), TypeKind::Fun(a2, b2)) => {
1120 self.unify(a1, a2)?;
1121 self.unify(b1, b2)
1122 }
1123 (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1124 if ts1.len() != ts2.len() {
1125 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1126 }
1127 for (a, b) in ts1.iter().zip(ts2.iter()) {
1128 self.unify(a, b)?;
1129 }
1130 Ok(())
1131 }
1132 (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1133 if f1.len() != f2.len() {
1134 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1135 }
1136 for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1137 if n1 != n2 {
1138 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1139 }
1140 self.unify(t1, t2)?;
1141 }
1142 Ok(())
1143 }
1144 (TypeKind::Record(fields), TypeKind::App(head, arg))
1145 | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1146 TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1147 let elem_ty = record_elem_type_unifier(fields, self)?;
1148 self.unify(arg, &elem_ty)
1149 }
1150 TypeKind::Var(tv) => {
1151 self.unify(
1152 &Type::new(TypeKind::Var(tv.clone())),
1153 &Type::builtin(BuiltinTypeId::Dict),
1154 )?;
1155 let elem_ty = record_elem_type_unifier(fields, self)?;
1156 self.unify(arg, &elem_ty)
1157 }
1158 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1159 },
1160 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1161 }
1162 }
1163
1164 fn into_subst(mut self) -> Subst {
1165 let mut out = Subst::new_sync();
1166 for id in 0..self.subs.len() {
1167 if let Some(ty) = self.subs[id].clone() {
1168 let pruned = self.prune(&ty);
1169 out = out.insert(id, pruned);
1170 }
1171 }
1172 out
1173 }
1174}
1175
1176fn record_elem_type_unifier(
1177 fields: &[(Symbol, Type)],
1178 unifier: &mut Unifier<'_>,
1179) -> Result<Type, TypeError> {
1180 let mut iter = fields.iter();
1181 let first = match iter.next() {
1182 Some((_, ty)) => ty.clone(),
1183 None => return Err(TypeError::UnsupportedExpr("empty record")),
1184 };
1185 for (_, ty) in iter {
1186 unifier.unify(&first, ty)?;
1187 }
1188 Ok(unifier.apply_type(&first))
1189}
1190
1191fn bind(tv: &TypeVar, t: &Type) -> Result<Subst, TypeError> {
1192 if let TypeKind::Var(var) = t.as_ref()
1193 && var.id == tv.id
1194 {
1195 return Ok(Subst::new_sync());
1196 }
1197 if t.ftv().contains(&tv.id) {
1198 Err(TypeError::Occurs(tv.id, t.to_string()))
1199 } else {
1200 Ok(Subst::new_sync().insert(tv.id, t.clone()))
1201 }
1202}
1203
1204fn record_elem_type(fields: &[(Symbol, Type)]) -> Result<(Subst, Type), TypeError> {
1205 let mut iter = fields.iter();
1206 let first = match iter.next() {
1207 Some((_, ty)) => ty.clone(),
1208 None => return Err(TypeError::UnsupportedExpr("empty record")),
1209 };
1210 let mut subst = Subst::new_sync();
1211 let mut current = first;
1212 for (_, ty) in iter {
1213 let s_next = unify(¤t.apply(&subst), &ty.apply(&subst))?;
1214 subst = compose_subst(s_next, subst);
1215 current = current.apply(&subst);
1216 }
1217 Ok((subst.clone(), current.apply(&subst)))
1218}
1219
1220pub fn unify(t1: &Type, t2: &Type) -> Result<Subst, TypeError> {
1227 match (t1.as_ref(), t2.as_ref()) {
1228 (TypeKind::Fun(l1, r1), TypeKind::Fun(l2, r2)) => {
1229 let s1 = unify(l1, l2)?;
1230 let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1231 Ok(compose_subst(s2, s1))
1232 }
1233 (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1234 if f1.len() != f2.len() {
1235 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1236 }
1237 let mut subst = Subst::new_sync();
1238 for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1239 if n1 != n2 {
1240 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1241 }
1242 let s_next = unify(&t1.apply(&subst), &t2.apply(&subst))?;
1243 subst = compose_subst(s_next, subst);
1244 }
1245 Ok(subst)
1246 }
1247 (TypeKind::Record(fields), TypeKind::App(head, arg))
1248 | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1249 TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1250 let (s_fields, elem_ty) = record_elem_type(fields)?;
1251 let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1252 Ok(compose_subst(s_arg, s_fields))
1253 }
1254 TypeKind::Var(tv) => {
1255 let s_head = bind(tv, &Type::builtin(BuiltinTypeId::Dict))?;
1256 let arg = arg.apply(&s_head);
1257 let (s_fields, elem_ty) = record_elem_type(fields)?;
1258 let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1259 Ok(compose_subst(s_arg, compose_subst(s_fields, s_head)))
1260 }
1261 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1262 },
1263 (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1264 let s1 = unify(l1, l2)?;
1265 let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1266 Ok(compose_subst(s2, s1))
1267 }
1268 (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1269 if ts1.len() != ts2.len() {
1270 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1271 }
1272 let mut s = Subst::new_sync();
1273 for (a, b) in ts1.iter().zip(ts2.iter()) {
1274 let s_next = unify(&a.apply(&s), &b.apply(&s))?;
1275 s = compose_subst(s_next, s);
1276 }
1277 Ok(s)
1278 }
1279 (TypeKind::Var(tv), t) | (t, TypeKind::Var(tv)) => bind(tv, &Type::new(t.clone())),
1280 (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(Subst::new_sync()),
1281 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1282 }
1283}
1284
1285#[derive(Default, Debug, Clone)]
1286pub struct TypeEnv {
1287 pub values: HashTrieMapSync<Symbol, Vec<Scheme>>,
1288}
1289
1290impl TypeEnv {
1291 pub fn new() -> Self {
1292 Self {
1293 values: HashTrieMapSync::new_sync(),
1294 }
1295 }
1296
1297 pub fn extend(&mut self, name: Symbol, scheme: Scheme) {
1298 self.values = self.values.insert(name, vec![scheme]);
1299 }
1300
1301 pub fn extend_overload(&mut self, name: Symbol, scheme: Scheme) {
1302 let mut schemes = self.values.get(&name).cloned().unwrap_or_default();
1303 schemes.push(scheme);
1304 self.values = self.values.insert(name, schemes);
1305 }
1306
1307 pub fn remove(&mut self, name: &Symbol) {
1308 self.values = self.values.remove(name);
1309 }
1310
1311 pub fn lookup(&self, name: &Symbol) -> Option<&[Scheme]> {
1312 self.values.get(name).map(|schemes| schemes.as_slice())
1313 }
1314}
1315
1316impl Types for TypeEnv {
1317 fn apply(&self, s: &Subst) -> Self {
1318 let mut values = HashTrieMapSync::new_sync();
1319 for (k, v) in self.values.iter() {
1320 let updated = v
1321 .iter()
1322 .map(|scheme| {
1323 if scheme.vars.is_empty() && !subst_is_empty(s) {
1326 scheme.apply(s)
1327 } else {
1328 scheme.clone()
1329 }
1330 })
1331 .collect();
1332 values = values.insert(k.clone(), updated);
1333 }
1334 TypeEnv { values }
1335 }
1336
1337 fn ftv(&self) -> HashSet<TypeVarId> {
1338 self.values
1339 .iter()
1340 .flat_map(|(_, schemes)| schemes.iter().flat_map(Types::ftv))
1341 .collect()
1342 }
1343}
1344
1345#[derive(Default, Debug, Clone)]
1346pub struct TypeVarSupply {
1347 counter: TypeVarId,
1348}
1349
1350impl TypeVarSupply {
1351 pub fn new() -> Self {
1352 Self { counter: 0 }
1353 }
1354
1355 pub fn fresh(&mut self, name_hint: impl Into<Option<Symbol>>) -> TypeVar {
1356 let tv = TypeVar::new(self.counter, name_hint.into());
1357 self.counter += 1;
1358 tv
1359 }
1360}
1361
1362fn is_integral_literal_expr(expr: &Expr) -> bool {
1363 matches!(expr, Expr::Int(..) | Expr::Uint(..))
1364}
1365
1366pub fn generalize(env: &TypeEnv, preds: Vec<Predicate>, typ: Type) -> Scheme {
1369 let mut vars: Vec<TypeVar> = typ
1370 .ftv()
1371 .union(&preds.ftv())
1372 .copied()
1373 .collect::<HashSet<_>>()
1374 .difference(&env.ftv())
1375 .cloned()
1376 .map(|id| TypeVar::new(id, None))
1377 .collect();
1378 vars.sort_by_key(|v| v.id);
1379 Scheme::new(vars, preds, typ)
1380}
1381
1382pub fn instantiate(scheme: &Scheme, supply: &mut TypeVarSupply) -> (Vec<Predicate>, Type) {
1383 let mut subst = Subst::new_sync();
1386 for v in &scheme.vars {
1387 subst = subst.insert(v.id, Type::var(supply.fresh(v.name.clone())));
1388 }
1389 (scheme.preds.apply(&subst), scheme.typ.apply(&subst))
1390}
1391
1392#[derive(Clone, Debug)]
1394pub struct AdtParam {
1395 pub name: Symbol,
1396 pub var: TypeVar,
1397}
1398
1399#[derive(Clone, Debug)]
1401pub struct AdtVariant {
1402 pub name: Symbol,
1403 pub args: Vec<Type>,
1404}
1405
1406#[derive(Clone, Debug)]
1412pub struct AdtDecl {
1413 pub name: Symbol,
1414 pub params: Vec<AdtParam>,
1415 pub variants: Vec<AdtVariant>,
1416}
1417
1418impl AdtDecl {
1419 pub fn new(name: &Symbol, param_names: &[Symbol], supply: &mut TypeVarSupply) -> Self {
1420 let params = param_names
1421 .iter()
1422 .map(|p| AdtParam {
1423 name: p.clone(),
1424 var: supply.fresh(Some(p.clone())),
1425 })
1426 .collect();
1427 Self {
1428 name: name.clone(),
1429 params,
1430 variants: Vec::new(),
1431 }
1432 }
1433
1434 pub fn param_type(&self, name: &Symbol) -> Option<Type> {
1435 self.params
1436 .iter()
1437 .find(|p| &p.name == name)
1438 .map(|p| Type::var(p.var.clone()))
1439 }
1440
1441 pub fn add_variant(&mut self, name: Symbol, args: Vec<Type>) {
1442 self.variants.push(AdtVariant { name, args });
1443 }
1444
1445 pub fn result_type(&self) -> Type {
1446 let mut ty = Type::con(&self.name, self.params.len());
1447 for param in &self.params {
1448 ty = Type::app(ty, Type::var(param.var.clone()));
1449 }
1450 ty
1451 }
1452
1453 pub fn constructor_schemes(&self) -> Vec<(Symbol, Scheme)> {
1456 let result_ty = self.result_type();
1457 let vars: Vec<TypeVar> = self.params.iter().map(|p| p.var.clone()).collect();
1458 let mut out = Vec::new();
1459 for variant in &self.variants {
1460 let mut typ = result_ty.clone();
1461 for arg in variant.args.iter().rev() {
1462 typ = Type::fun(arg.clone(), typ);
1463 }
1464 out.push((variant.name.clone(), Scheme::new(vars.clone(), vec![], typ)));
1465 }
1466 out
1467 }
1468}
1469
1470#[derive(Clone, Debug)]
1471pub struct Class {
1472 pub supers: Vec<Symbol>,
1473}
1474
1475impl Class {
1476 pub fn new(supers: Vec<Symbol>) -> Self {
1477 Self { supers }
1478 }
1479}
1480
1481#[derive(Clone, Debug)]
1482pub struct Instance {
1483 pub context: Vec<Predicate>,
1484 pub head: Predicate,
1485}
1486
1487impl Instance {
1488 pub fn new(context: Vec<Predicate>, head: Predicate) -> Self {
1489 Self { context, head }
1490 }
1491}
1492
1493#[derive(Default, Debug, Clone)]
1494pub struct ClassEnv {
1495 pub classes: HashMap<Symbol, Class>,
1496 pub instances: HashMap<Symbol, Vec<Instance>>,
1497}
1498
1499impl ClassEnv {
1500 pub fn new() -> Self {
1501 Self {
1502 classes: HashMap::new(),
1503 instances: HashMap::new(),
1504 }
1505 }
1506
1507 pub fn add_class(&mut self, name: Symbol, supers: Vec<Symbol>) {
1508 self.classes.insert(name, Class::new(supers));
1509 }
1510
1511 pub fn add_instance(&mut self, class: Symbol, inst: Instance) {
1512 self.instances.entry(class).or_default().push(inst);
1513 }
1514
1515 pub fn supers_of(&self, class: &Symbol) -> Vec<Symbol> {
1516 self.classes
1517 .get(class)
1518 .map(|c| c.supers.clone())
1519 .unwrap_or_default()
1520 }
1521}
1522
1523pub fn entails(
1524 class_env: &ClassEnv,
1525 given: &[Predicate],
1526 pred: &Predicate,
1527) -> Result<bool, TypeError> {
1528 let mut closure: Vec<Predicate> = given.to_vec();
1530 let mut i = 0;
1531 while i < closure.len() {
1532 let p = closure[i].clone();
1533 for sup in class_env.supers_of(&p.class) {
1534 closure.push(Predicate::new(sup, p.typ.clone()));
1535 }
1536 i += 1;
1537 }
1538
1539 if closure
1540 .iter()
1541 .any(|p| p.class == pred.class && p.typ == pred.typ)
1542 {
1543 return Ok(true);
1544 }
1545
1546 if !class_env.classes.contains_key(&pred.class) {
1547 return Err(TypeError::UnknownClass(pred.class.clone()));
1548 }
1549
1550 if let Some(instances) = class_env.instances.get(&pred.class) {
1551 for inst in instances {
1552 if let Ok(s) = unify(&inst.head.typ, &pred.typ) {
1553 let ctx = inst.context.apply(&s);
1554 if ctx
1555 .iter()
1556 .all(|c| entails(class_env, &closure, c).unwrap_or(false))
1557 {
1558 return Ok(true);
1559 }
1560 }
1561 }
1562 }
1563 Ok(false)
1564}
1565
1566#[derive(Default, Debug, Clone)]
1567pub struct TypeSystem {
1568 pub env: TypeEnv,
1569 pub classes: ClassEnv,
1570 pub adts: HashMap<Symbol, AdtDecl>,
1571 pub class_info: HashMap<Symbol, ClassInfo>,
1572 pub class_methods: HashMap<Symbol, ClassMethodInfo>,
1573 pub declared_values: HashSet<Symbol>,
1578 pub supply: TypeVarSupply,
1579 limits: TypeSystemLimits,
1580}
1581
1582#[derive(Clone, Debug)]
1592pub struct ClassInfo {
1593 pub name: Symbol,
1594 pub params: Vec<Symbol>,
1595 pub supers: Vec<Symbol>,
1596 pub methods: BTreeMap<Symbol, Scheme>,
1597}
1598
1599#[derive(Clone, Debug)]
1600pub struct ClassMethodInfo {
1601 pub class: Symbol,
1602 pub scheme: Scheme,
1603}
1604
1605#[derive(Clone, Debug)]
1606pub struct PreparedInstanceDecl {
1607 pub span: Span,
1608 pub class: Symbol,
1609 pub head: Type,
1610 pub context: Vec<Predicate>,
1611}
1612
1613impl TypeSystem {
1614 pub fn new() -> Self {
1615 Self {
1616 env: TypeEnv::new(),
1617 classes: ClassEnv::new(),
1618 adts: HashMap::new(),
1619 class_info: HashMap::new(),
1620 class_methods: HashMap::new(),
1621 declared_values: HashSet::new(),
1622 supply: TypeVarSupply::new(),
1623 limits: TypeSystemLimits::default(),
1624 }
1625 }
1626
1627 pub fn fresh_type_var(&mut self, name: Option<Symbol>) -> TypeVar {
1628 self.supply.fresh(name)
1629 }
1630
1631 pub fn set_limits(&mut self, limits: TypeSystemLimits) {
1632 self.limits = limits;
1633 }
1634
1635 pub fn new_with_prelude() -> Result<Self, TypeError> {
1636 let mut ts = TypeSystem::new();
1637 prelude::build_prelude(&mut ts)?;
1638 Ok(ts)
1639 }
1640
1641 fn register_decl(&mut self, decl: &Decl) -> Result<(), TypeError> {
1642 match decl {
1643 Decl::Type(ty) => self.register_type_decl(ty),
1644 Decl::Class(class_decl) => self.register_class_decl(class_decl),
1645 Decl::Instance(inst_decl) => {
1646 let _ = self.register_instance_decl(inst_decl)?;
1647 Ok(())
1648 }
1649 Decl::Fn(fd) => self.register_fn_decls(std::slice::from_ref(fd)),
1650 Decl::DeclareFn(fd) => self.inject_declare_fn_decl(fd),
1651 Decl::Import(..) => Ok(()),
1652 }
1653 }
1654
1655 pub fn register_decls(&mut self, decls: &[Decl]) -> Result<(), TypeError> {
1656 let mut pending_fns: Vec<FnDecl> = Vec::new();
1657 for decl in decls {
1658 if let Decl::Fn(fd) = decl {
1659 pending_fns.push(fd.clone());
1660 continue;
1661 }
1662
1663 if !pending_fns.is_empty() {
1664 self.register_fn_decls(&pending_fns)?;
1665 pending_fns.clear();
1666 }
1667
1668 self.register_decl(decl)?;
1669 }
1670 if !pending_fns.is_empty() {
1671 self.register_fn_decls(&pending_fns)?;
1672 }
1673 Ok(())
1674 }
1675
1676 pub fn add_value(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1677 let name = sym(name.as_ref());
1678 self.declared_values.remove(&name);
1679 self.env.extend(name, scheme);
1680 }
1681
1682 pub fn add_overload(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1683 let name = sym(name.as_ref());
1684 self.declared_values.remove(&name);
1685 self.env.extend_overload(name, scheme);
1686 }
1687
1688 pub fn register_instance(&mut self, class: impl AsRef<str>, inst: Instance) {
1689 self.classes.add_instance(sym(class.as_ref()), inst);
1690 }
1691
1692 pub fn register_class_decl(&mut self, decl: &ClassDecl) -> Result<(), TypeError> {
1693 let span = decl.span;
1694 (|| {
1695 if self.class_info.contains_key(&decl.name)
1699 || self.classes.classes.contains_key(&decl.name)
1700 {
1701 return Err(TypeError::DuplicateClass(decl.name.clone()));
1702 }
1703 if decl.params.is_empty() {
1704 return Err(TypeError::InvalidClassArity {
1705 class: decl.name.clone(),
1706 got: decl.params.len(),
1707 });
1708 }
1709 let params = decl.params.clone();
1710
1711 let mut supers = Vec::with_capacity(decl.supers.len());
1717 if !decl.supers.is_empty() && params.len() != 1 {
1718 return Err(TypeError::UnsupportedExpr(
1719 "multi-parameter classes cannot declare superclasses yet",
1720 ));
1721 }
1722 for sup in &decl.supers {
1723 let mut vars = HashMap::new();
1724 let param = params[0].clone();
1725 let param_tv = self.supply.fresh(Some(param.clone()));
1726 vars.insert(param, param_tv.clone());
1727 let sup_ty = type_from_annotation_expr_vars(
1728 &self.adts,
1729 &sup.typ,
1730 &mut vars,
1731 &mut self.supply,
1732 )?;
1733 if sup_ty != Type::var(param_tv) {
1734 return Err(TypeError::UnsupportedExpr(
1735 "superclass constraints must be of the form `<= C a`",
1736 ));
1737 }
1738 supers.push(sup.class.to_dotted_symbol());
1739 }
1740
1741 self.classes.add_class(decl.name.clone(), supers.clone());
1742
1743 let mut methods = BTreeMap::new();
1744 for ClassMethodSig { name, typ } in &decl.methods {
1745 if self.env.lookup(name).is_some() || self.class_methods.contains_key(name) {
1746 return Err(TypeError::DuplicateClassMethod(name.clone()));
1747 }
1748
1749 let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1750 let mut param_tvs: Vec<TypeVar> = Vec::with_capacity(params.len());
1751 for param in ¶ms {
1752 let tv = self.supply.fresh(Some(param.clone()));
1753 vars.insert(param.clone(), tv.clone());
1754 param_tvs.push(tv);
1755 }
1756
1757 let ty =
1758 type_from_annotation_expr_vars(&self.adts, typ, &mut vars, &mut self.supply)?;
1759
1760 let mut scheme_vars: Vec<TypeVar> = vars.values().cloned().collect();
1761 scheme_vars.sort_by_key(|tv| tv.id);
1762 scheme_vars.dedup_by_key(|tv| tv.id);
1763
1764 let class_pred = Predicate {
1765 class: decl.name.clone(),
1766 typ: if param_tvs.len() == 1 {
1767 Type::var(param_tvs[0].clone())
1768 } else {
1769 Type::tuple(param_tvs.into_iter().map(Type::var).collect())
1770 },
1771 };
1772 let scheme = Scheme::new(scheme_vars, vec![class_pred], ty);
1773
1774 self.env.extend(name.clone(), scheme.clone());
1775 self.class_methods.insert(
1776 name.clone(),
1777 ClassMethodInfo {
1778 class: decl.name.clone(),
1779 scheme: scheme.clone(),
1780 },
1781 );
1782 methods.insert(name.clone(), scheme);
1783 }
1784
1785 self.class_info.insert(
1786 decl.name.clone(),
1787 ClassInfo {
1788 name: decl.name.clone(),
1789 params,
1790 supers,
1791 methods,
1792 },
1793 );
1794 Ok(())
1795 })()
1796 .map_err(|err| with_span(&span, err))
1797 }
1798
1799 pub fn register_instance_decl(
1800 &mut self,
1801 decl: &InstanceDecl,
1802 ) -> Result<PreparedInstanceDecl, TypeError> {
1803 let span = decl.span;
1804 (|| {
1805 let class = decl.class.clone();
1806 if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
1807 return Err(TypeError::UnknownClass(class));
1808 }
1809
1810 let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1811 let head = type_from_annotation_expr_vars(
1812 &self.adts,
1813 &decl.head,
1814 &mut vars,
1815 &mut self.supply,
1816 )?;
1817 let context = predicates_from_constraints(
1818 &self.adts,
1819 &decl.context,
1820 &mut vars,
1821 &mut self.supply,
1822 )?;
1823
1824 let inst = Instance::new(
1825 context.clone(),
1826 Predicate {
1827 class: decl.class.clone(),
1828 typ: head.clone(),
1829 },
1830 );
1831
1832 if let Some(info) = self.class_info.get(&decl.class) {
1834 for method in &decl.methods {
1835 if !info.methods.contains_key(&method.name) {
1836 return Err(TypeError::UnknownInstanceMethod {
1837 class: decl.class.clone(),
1838 method: method.name.clone(),
1839 });
1840 }
1841 }
1842 for method_name in info.methods.keys() {
1843 if !decl.methods.iter().any(|m| &m.name == method_name) {
1844 return Err(TypeError::MissingInstanceMethod {
1845 class: decl.class.clone(),
1846 method: method_name.clone(),
1847 });
1848 }
1849 }
1850 }
1851
1852 self.classes.add_instance(decl.class.clone(), inst);
1853 Ok(PreparedInstanceDecl {
1854 span,
1855 class: decl.class.clone(),
1856 head,
1857 context,
1858 })
1859 })()
1860 .map_err(|err| with_span(&span, err))
1861 }
1862
1863 pub fn prepare_instance_decl(
1864 &mut self,
1865 decl: &InstanceDecl,
1866 ) -> Result<PreparedInstanceDecl, TypeError> {
1867 let span = decl.span;
1868 (|| {
1869 let class = decl.class.clone();
1870 if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
1871 return Err(TypeError::UnknownClass(class));
1872 }
1873
1874 let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1875 let head = type_from_annotation_expr_vars(
1876 &self.adts,
1877 &decl.head,
1878 &mut vars,
1879 &mut self.supply,
1880 )?;
1881 let context = predicates_from_constraints(
1882 &self.adts,
1883 &decl.context,
1884 &mut vars,
1885 &mut self.supply,
1886 )?;
1887
1888 if let Some(info) = self.class_info.get(&decl.class) {
1890 for method in &decl.methods {
1891 if !info.methods.contains_key(&method.name) {
1892 return Err(TypeError::UnknownInstanceMethod {
1893 class: decl.class.clone(),
1894 method: method.name.clone(),
1895 });
1896 }
1897 }
1898 for method_name in info.methods.keys() {
1899 if !decl.methods.iter().any(|m| &m.name == method_name) {
1900 return Err(TypeError::MissingInstanceMethod {
1901 class: decl.class.clone(),
1902 method: method_name.clone(),
1903 });
1904 }
1905 }
1906 }
1907
1908 Ok(PreparedInstanceDecl {
1909 span,
1910 class: decl.class.clone(),
1911 head,
1912 context,
1913 })
1914 })()
1915 .map_err(|err| with_span(&span, err))
1916 }
1917
1918 pub fn register_fn_decls(&mut self, decls: &[FnDecl]) -> Result<(), TypeError> {
1919 if decls.is_empty() {
1920 return Ok(());
1921 }
1922
1923 let saved_env = self.env.clone();
1924 let saved_declared = self.declared_values.clone();
1925
1926 let result: Result<(), TypeError> = (|| {
1927 #[derive(Clone)]
1928 struct FnInfo {
1929 decl: FnDecl,
1930 expected: Type,
1931 declared_preds: Vec<Predicate>,
1932 scheme: Scheme,
1933 ann_vars: HashMap<Symbol, TypeVar>,
1934 }
1935
1936 let mut infos: Vec<FnInfo> = Vec::with_capacity(decls.len());
1937 let mut seen_names = HashSet::new();
1938
1939 for decl in decls {
1940 let span = decl.span;
1941 let info = (|| {
1942 let name = &decl.name.name;
1943 if !seen_names.insert(name.clone()) {
1944 return Err(TypeError::DuplicateValue(name.clone()));
1945 }
1946
1947 if self.env.lookup(name).is_some() {
1948 if self.declared_values.remove(name) {
1949 self.env.remove(name);
1951 } else {
1952 return Err(TypeError::DuplicateValue(name.clone()));
1953 }
1954 }
1955
1956 let mut sig = decl.ret.clone();
1957 for (_, ann) in decl.params.iter().rev() {
1958 let span = Span::from_begin_end(ann.span().begin, sig.span().end);
1959 sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
1960 }
1961
1962 let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
1963 let expected = type_from_annotation_expr_vars(
1964 &self.adts,
1965 &sig,
1966 &mut ann_vars,
1967 &mut self.supply,
1968 )?;
1969 let declared_preds = predicates_from_constraints(
1970 &self.adts,
1971 &decl.constraints,
1972 &mut ann_vars,
1973 &mut self.supply,
1974 )?;
1975
1976 let var_arities: HashMap<TypeVarId, usize> = ann_vars
1978 .values()
1979 .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
1980 .collect();
1981 for pred in &declared_preds {
1982 let _ = entails(&self.classes, &[], pred)?;
1983 let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
1984 else {
1985 continue;
1986 };
1987 let args: Vec<Type> = if expected_arities.len() == 1 {
1988 vec![pred.typ.clone()]
1989 } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
1990 if parts.len() != expected_arities.len() {
1991 continue;
1992 }
1993 parts.clone()
1994 } else {
1995 continue;
1996 };
1997
1998 for (arg, expected_arity) in
1999 args.iter().zip(expected_arities.iter().copied())
2000 {
2001 let got =
2002 type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
2003 TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
2004 _ => None,
2005 });
2006 let Some(got) = got else {
2007 continue;
2008 };
2009 if got != expected_arity {
2010 return Err(TypeError::KindMismatch {
2011 class: pred.class.clone(),
2012 expected: expected_arity,
2013 got,
2014 typ: arg.to_string(),
2015 });
2016 }
2017 }
2018 }
2019
2020 let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
2021 vars.sort_by_key(|v| v.id);
2022 let scheme = Scheme::new(vars, declared_preds.clone(), expected.clone());
2023 reject_ambiguous_scheme(&scheme)?;
2024
2025 Ok(FnInfo {
2026 decl: decl.clone(),
2027 expected,
2028 declared_preds,
2029 scheme,
2030 ann_vars,
2031 })
2032 })();
2033
2034 infos.push(info.map_err(|err| with_span(&span, err))?);
2035 }
2036
2037 for info in &infos {
2040 self.env
2041 .extend(info.decl.name.name.clone(), info.scheme.clone());
2042 }
2043
2044 for info in infos {
2045 let span = info.decl.span;
2046 let mut lam_body = info.decl.body.clone();
2047 let mut lam_end = lam_body.span().end;
2048 for (param, ann) in info.decl.params.iter().rev() {
2049 let lam_constraints = Vec::new();
2050 let span = Span::from_begin_end(param.span.begin, lam_end);
2051 lam_body = Arc::new(Expr::Lam(
2052 span,
2053 Scope::new_sync(),
2054 param.clone(),
2055 Some(ann.clone()),
2056 lam_constraints,
2057 lam_body,
2058 ));
2059 lam_end = lam_body.span().end;
2060 }
2061
2062 let (typed, preds, inferred) = infer_typed(self, lam_body.as_ref())?;
2063 let s = unify(&inferred, &info.expected)?;
2064 let preds = preds.apply(&s);
2065 let inferred = inferred.apply(&s);
2066 let declared_preds = info.declared_preds.apply(&s);
2067 let expected = info.expected.apply(&s);
2068
2069 let var_arities: HashMap<TypeVarId, usize> = info
2071 .ann_vars
2072 .values()
2073 .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
2074 .collect();
2075 for pred in &declared_preds {
2076 let _ = entails(&self.classes, &[], pred)?;
2077 let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
2078 else {
2079 continue;
2080 };
2081 let args: Vec<Type> = if expected_arities.len() == 1 {
2082 vec![pred.typ.clone()]
2083 } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2084 if parts.len() != expected_arities.len() {
2085 continue;
2086 }
2087 parts.clone()
2088 } else {
2089 continue;
2090 };
2091
2092 for (arg, expected_arity) in args.iter().zip(expected_arities.iter().copied()) {
2093 let got = type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
2094 TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
2095 _ => None,
2096 });
2097 let Some(got) = got else {
2098 continue;
2099 };
2100 if got != expected_arity {
2101 return Err(with_span(
2102 &span,
2103 TypeError::KindMismatch {
2104 class: pred.class.clone(),
2105 expected: expected_arity,
2106 got,
2107 typ: arg.to_string(),
2108 },
2109 ));
2110 }
2111 }
2112 }
2113
2114 check_non_ground_predicates_declared(&self.classes, &declared_preds, &preds)
2115 .map_err(|err| with_span(&span, err))?;
2116
2117 let _ = inferred;
2118 let _ = typed;
2119 }
2120
2121 Ok(())
2122 })();
2123
2124 if result.is_err() {
2125 self.env = saved_env;
2126 self.declared_values = saved_declared;
2127 }
2128 result
2129 }
2130
2131 pub fn inject_declare_fn_decl(&mut self, decl: &DeclareFnDecl) -> Result<(), TypeError> {
2132 let span = decl.span;
2133 (|| {
2134 let mut sig = decl.ret.clone();
2136 for (_, ann) in decl.params.iter().rev() {
2137 let span = Span::from_begin_end(ann.span().begin, sig.span().end);
2138 sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
2139 }
2140
2141 let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
2142 let expected =
2143 type_from_annotation_expr_vars(&self.adts, &sig, &mut ann_vars, &mut self.supply)?;
2144 let declared_preds = predicates_from_constraints(
2145 &self.adts,
2146 &decl.constraints,
2147 &mut ann_vars,
2148 &mut self.supply,
2149 )?;
2150
2151 let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
2152 vars.sort_by_key(|v| v.id);
2153 let scheme = Scheme::new(vars, declared_preds, expected);
2154 reject_ambiguous_scheme(&scheme)?;
2155
2156 for pred in &scheme.preds {
2158 let _ = entails(&self.classes, &[], pred)?;
2159 }
2160
2161 let name = &decl.name.name;
2162
2163 if self.env.lookup(name).is_some() && !self.declared_values.contains(name) {
2166 return Ok(());
2167 }
2168
2169 if let Some(existing) = self.env.lookup(name) {
2170 if existing.iter().any(|s| scheme_compatible(s, &scheme)) {
2171 return Ok(());
2172 }
2173 return Err(TypeError::DuplicateValue(decl.name.name.clone()));
2174 }
2175
2176 self.env.extend(decl.name.name.clone(), scheme);
2177 self.declared_values.insert(decl.name.name.clone());
2178 Ok(())
2179 })()
2180 .map_err(|err| with_span(&span, err))
2181 }
2182
2183 pub fn instantiate_class_method_for_head(
2184 &mut self,
2185 class: &Symbol,
2186 method: &Symbol,
2187 head: &Type,
2188 ) -> Result<Type, TypeError> {
2189 let info = self
2190 .class_info
2191 .get(class)
2192 .ok_or_else(|| TypeError::UnknownClass(class.clone()))?;
2193 let scheme = info
2194 .methods
2195 .get(method)
2196 .ok_or_else(|| TypeError::UnknownInstanceMethod {
2197 class: class.clone(),
2198 method: method.clone(),
2199 })?;
2200
2201 let (preds, typ) = instantiate(scheme, &mut self.supply);
2202 let class_pred =
2203 preds
2204 .iter()
2205 .find(|p| &p.class == class)
2206 .ok_or(TypeError::UnsupportedExpr(
2207 "class method scheme missing class predicate",
2208 ))?;
2209 let s = unify(&class_pred.typ, head)?;
2210 Ok(typ.apply(&s))
2211 }
2212
2213 pub fn typecheck_instance_method(
2214 &mut self,
2215 prepared: &PreparedInstanceDecl,
2216 method: &InstanceMethodImpl,
2217 ) -> Result<TypedExpr, TypeError> {
2218 let expected =
2219 self.instantiate_class_method_for_head(&prepared.class, &method.name, &prepared.head)?;
2220 let (typed, preds, actual) = infer_typed(self, method.body.as_ref())?;
2221 let s = unify(&actual, &expected)?;
2222 let typed = typed.apply(&s);
2223 let preds = preds.apply(&s);
2224
2225 let mut given = prepared.context.clone();
2231
2232 given.push(Predicate::new(
2235 prepared.class.clone(),
2236 prepared.head.clone(),
2237 ));
2238 let mut i = 0;
2239 while i < given.len() {
2240 let p = given[i].clone();
2241 for sup in self.classes.supers_of(&p.class) {
2242 given.push(Predicate::new(sup, p.typ.clone()));
2243 }
2244 i += 1;
2245 }
2246
2247 for pred in &preds {
2248 if pred.typ.ftv().is_empty() {
2249 if !entails(&self.classes, &given, pred)? {
2250 return Err(TypeError::NoInstance(
2251 pred.class.clone(),
2252 pred.typ.to_string(),
2253 ));
2254 }
2255 } else if !given
2256 .iter()
2257 .any(|p| p.class == pred.class && p.typ == pred.typ)
2258 {
2259 return Err(TypeError::MissingInstanceConstraint {
2260 method: method.name.clone(),
2261 class: pred.class.clone(),
2262 typ: pred.typ.to_string(),
2263 });
2264 }
2265 }
2266
2267 Ok(typed)
2268 }
2269
2270 pub fn register_adt(&mut self, adt: &AdtDecl) {
2274 self.adts.insert(adt.name.clone(), adt.clone());
2275 for (name, scheme) in adt.constructor_schemes() {
2276 self.register_value_scheme(&name, scheme);
2277 }
2278 }
2279
2280 pub fn adt_from_decl(&mut self, decl: &TypeDecl) -> Result<AdtDecl, TypeError> {
2281 let mut adt = AdtDecl::new(&decl.name, &decl.params, &mut self.supply);
2282 let mut param_map: HashMap<Symbol, TypeVar> = HashMap::new();
2283 for param in &adt.params {
2284 param_map.insert(param.name.clone(), param.var.clone());
2285 }
2286
2287 for variant in &decl.variants {
2288 let mut args = Vec::new();
2289 for arg in &variant.args {
2290 let ty = self.type_from_expr(decl, ¶m_map, arg)?;
2291 args.push(ty);
2292 }
2293 adt.add_variant(variant.name.clone(), args);
2294 }
2295 Ok(adt)
2296 }
2297
2298 pub fn register_type_decl(&mut self, decl: &TypeDecl) -> Result<(), TypeError> {
2299 if BuiltinTypeId::from_symbol(&decl.name).is_some() {
2300 return Err(TypeError::ReservedTypeName(decl.name.clone()));
2301 }
2302 let adt = self.adt_from_decl(decl)?;
2303 self.register_adt(&adt);
2304 Ok(())
2305 }
2306
2307 fn type_from_expr(
2308 &mut self,
2309 decl: &TypeDecl,
2310 params: &HashMap<Symbol, TypeVar>,
2311 expr: &TypeExpr,
2312 ) -> Result<Type, TypeError> {
2313 let span = *expr.span();
2314 let res = (|| match expr {
2315 TypeExpr::Name(_, name) => {
2316 let name_sym = name.to_dotted_symbol();
2317 if let Some(tv) = params.get(&name_sym) {
2318 Ok(Type::var(tv.clone()))
2319 } else {
2320 let name = normalize_type_name(&name_sym);
2321 if let Some(arity) = self.type_arity(decl, &name) {
2322 Ok(Type::con(name, arity))
2323 } else {
2324 Err(TypeError::UnknownTypeName(name))
2325 }
2326 }
2327 }
2328 TypeExpr::App(_, fun, arg) => {
2329 let fty = self.type_from_expr(decl, params, fun)?;
2330 let aty = self.type_from_expr(decl, params, arg)?;
2331 Ok(type_app_with_result_syntax(fty, aty))
2332 }
2333 TypeExpr::Fun(_, arg, ret) => {
2334 let arg_ty = self.type_from_expr(decl, params, arg)?;
2335 let ret_ty = self.type_from_expr(decl, params, ret)?;
2336 Ok(Type::fun(arg_ty, ret_ty))
2337 }
2338 TypeExpr::Tuple(_, elems) => {
2339 let mut out = Vec::new();
2340 for elem in elems {
2341 out.push(self.type_from_expr(decl, params, elem)?);
2342 }
2343 Ok(Type::tuple(out))
2344 }
2345 TypeExpr::Record(_, fields) => {
2346 let mut out = Vec::new();
2347 for (name, ty) in fields {
2348 out.push((name.clone(), self.type_from_expr(decl, params, ty)?));
2349 }
2350 Ok(Type::record(out))
2351 }
2352 })();
2353 res.map_err(|err| with_span(&span, err))
2354 }
2355
2356 fn type_arity(&self, decl: &TypeDecl, name: &Symbol) -> Option<usize> {
2357 if &decl.name == name {
2358 return Some(decl.params.len());
2359 }
2360 if let Some(adt) = self.adts.get(name) {
2361 return Some(adt.params.len());
2362 }
2363 BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2364 }
2365
2366 fn register_value_scheme(&mut self, name: &Symbol, scheme: Scheme) {
2367 match self.env.lookup(name) {
2368 None => self.env.extend(name.clone(), scheme),
2369 Some(existing) => {
2370 if existing.iter().any(|s| unify(&s.typ, &scheme.typ).is_ok()) {
2371 return;
2372 }
2373 self.env.extend_overload(name.clone(), scheme);
2374 }
2375 }
2376 }
2377
2378 fn expected_class_param_arities(&self, class: &Symbol) -> Option<Vec<usize>> {
2379 let info = self.class_info.get(class)?;
2380 let mut out = vec![0usize; info.params.len()];
2381 for scheme in info.methods.values() {
2382 for (idx, param) in info.params.iter().enumerate() {
2383 let Some(tv) = scheme.vars.iter().find(|v| v.name.as_ref() == Some(param)) else {
2384 continue;
2385 };
2386 out[idx] = out[idx].max(max_head_app_arity_for_var(&scheme.typ, tv.id));
2387 }
2388 }
2389 Some(out)
2390 }
2391
2392 fn check_predicate_kind(&self, pred: &Predicate) -> Result<(), TypeError> {
2393 let Some(expected) = self.expected_class_param_arities(&pred.class) else {
2394 return Ok(());
2396 };
2397
2398 let args: Vec<Type> = if expected.len() == 1 {
2399 vec![pred.typ.clone()]
2400 } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2401 if parts.len() != expected.len() {
2402 return Ok(());
2403 }
2404 parts.clone()
2405 } else {
2406 return Ok(());
2407 };
2408
2409 for (arg, expected_arity) in args.iter().zip(expected.iter().copied()) {
2410 let Some(got) = type_term_remaining_arity(arg) else {
2411 continue;
2415 };
2416 if got != expected_arity {
2417 return Err(TypeError::KindMismatch {
2418 class: pred.class.clone(),
2419 expected: expected_arity,
2420 got,
2421 typ: arg.to_string(),
2422 });
2423 }
2424 }
2425 Ok(())
2426 }
2427
2428 fn check_predicate_kinds(&self, preds: &[Predicate]) -> Result<(), TypeError> {
2429 for pred in preds {
2430 self.check_predicate_kind(pred)?;
2431 }
2432 Ok(())
2433 }
2434}
2435
2436fn type_from_annotation_expr(
2437 adts: &HashMap<Symbol, AdtDecl>,
2438 expr: &TypeExpr,
2439) -> Result<Type, TypeError> {
2440 let span = *expr.span();
2441 let res = (|| match expr {
2442 TypeExpr::Name(_, name) => {
2443 let name = normalize_type_name(&name.to_dotted_symbol());
2444 match annotation_type_arity(adts, &name) {
2445 Some(arity) => Ok(Type::con(name, arity)),
2446 None => Err(TypeError::UnknownTypeName(name)),
2447 }
2448 }
2449 TypeExpr::App(_, fun, arg) => {
2450 let fty = type_from_annotation_expr(adts, fun)?;
2451 let aty = type_from_annotation_expr(adts, arg)?;
2452 Ok(type_app_with_result_syntax(fty, aty))
2453 }
2454 TypeExpr::Fun(_, arg, ret) => {
2455 let arg_ty = type_from_annotation_expr(adts, arg)?;
2456 let ret_ty = type_from_annotation_expr(adts, ret)?;
2457 Ok(Type::fun(arg_ty, ret_ty))
2458 }
2459 TypeExpr::Tuple(_, elems) => {
2460 let mut out = Vec::new();
2461 for elem in elems {
2462 out.push(type_from_annotation_expr(adts, elem)?);
2463 }
2464 Ok(Type::tuple(out))
2465 }
2466 TypeExpr::Record(_, fields) => {
2467 let mut out = Vec::new();
2468 for (name, ty) in fields {
2469 out.push((name.clone(), type_from_annotation_expr(adts, ty)?));
2470 }
2471 Ok(Type::record(out))
2472 }
2473 })();
2474 res.map_err(|err| with_span(&span, err))
2475}
2476
2477fn type_from_annotation_expr_vars(
2478 adts: &HashMap<Symbol, AdtDecl>,
2479 expr: &TypeExpr,
2480 vars: &mut HashMap<Symbol, TypeVar>,
2481 supply: &mut TypeVarSupply,
2482) -> Result<Type, TypeError> {
2483 let span = *expr.span();
2484 let res = (|| match expr {
2485 TypeExpr::Name(_, name) => {
2486 let name = normalize_type_name(&name.to_dotted_symbol());
2487 if let Some(arity) = annotation_type_arity(adts, &name) {
2488 Ok(Type::con(name, arity))
2489 } else if let Some(tv) = vars.get(&name) {
2490 Ok(Type::var(tv.clone()))
2491 } else {
2492 let is_upper = name
2493 .chars()
2494 .next()
2495 .map(|c| c.is_uppercase())
2496 .unwrap_or(false);
2497 if is_upper {
2498 return Err(TypeError::UnknownTypeName(name));
2499 }
2500 let tv = supply.fresh(Some(name.clone()));
2501 vars.insert(name.clone(), tv.clone());
2502 Ok(Type::var(tv))
2503 }
2504 }
2505 TypeExpr::App(_, fun, arg) => {
2506 let fty = type_from_annotation_expr_vars(adts, fun, vars, supply)?;
2507 let aty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2508 Ok(type_app_with_result_syntax(fty, aty))
2509 }
2510 TypeExpr::Fun(_, arg, ret) => {
2511 let arg_ty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2512 let ret_ty = type_from_annotation_expr_vars(adts, ret, vars, supply)?;
2513 Ok(Type::fun(arg_ty, ret_ty))
2514 }
2515 TypeExpr::Tuple(_, elems) => {
2516 let mut out = Vec::new();
2517 for elem in elems {
2518 out.push(type_from_annotation_expr_vars(adts, elem, vars, supply)?);
2519 }
2520 Ok(Type::tuple(out))
2521 }
2522 TypeExpr::Record(_, fields) => {
2523 let mut out = Vec::new();
2524 for (name, ty) in fields {
2525 out.push((
2526 name.clone(),
2527 type_from_annotation_expr_vars(adts, ty, vars, supply)?,
2528 ));
2529 }
2530 Ok(Type::record(out))
2531 }
2532 })();
2533 res.map_err(|err| with_span(&span, err))
2534}
2535
2536fn annotation_type_arity(adts: &HashMap<Symbol, AdtDecl>, name: &Symbol) -> Option<usize> {
2537 if let Some(adt) = adts.get(name) {
2538 return Some(adt.params.len());
2539 }
2540 BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2541}
2542
2543fn normalize_type_name(name: &Symbol) -> Symbol {
2544 if name.as_ref() == "str" {
2545 BuiltinTypeId::String.as_symbol()
2546 } else {
2547 name.clone()
2548 }
2549}
2550
2551fn type_app_with_result_syntax(fun: Type, arg: Type) -> Type {
2552 if let TypeKind::App(head, ok) = fun.as_ref()
2553 && matches!(
2554 head.as_ref(),
2555 TypeKind::Con(c)
2556 if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
2557 )
2558 {
2559 return Type::app(Type::app(head.clone(), arg), ok.clone());
2560 }
2561 Type::app(fun, arg)
2562}
2563
2564fn predicates_from_constraints(
2565 adts: &HashMap<Symbol, AdtDecl>,
2566 constraints: &[TypeConstraint],
2567 vars: &mut HashMap<Symbol, TypeVar>,
2568 supply: &mut TypeVarSupply,
2569) -> Result<Vec<Predicate>, TypeError> {
2570 let mut out = Vec::with_capacity(constraints.len());
2571 for constraint in constraints {
2572 let ty = type_from_annotation_expr_vars(adts, &constraint.typ, vars, supply)?;
2573 out.push(Predicate::new(constraint.class.as_ref(), ty));
2574 }
2575 Ok(out)
2576}
2577
2578#[derive(Clone, Debug, PartialEq, Eq)]
2579pub struct AdtConflict {
2580 pub name: Symbol,
2581 pub definitions: Vec<Type>,
2582}
2583
2584#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)]
2585#[error("conflicting ADT definitions: {conflicts:?}")]
2586pub struct CollectAdtsError {
2587 pub conflicts: Vec<AdtConflict>,
2588}
2589
2590pub fn collect_adts_in_types(types: Vec<Type>) -> Result<Vec<Type>, CollectAdtsError> {
2627 fn visit(
2628 typ: &Type,
2629 out: &mut Vec<Type>,
2630 seen: &mut HashSet<Type>,
2631 defs_by_name: &mut BTreeMap<Symbol, Vec<Type>>,
2632 ) {
2633 match typ.as_ref() {
2634 TypeKind::Var(_) => {}
2635 TypeKind::Con(tc) => {
2636 if tc.builtin_id.is_none() {
2638 let adt = Type::new(TypeKind::Con(tc.clone()));
2639 if seen.insert(adt.clone()) {
2640 out.push(adt.clone());
2641 }
2642 let defs = defs_by_name.entry(tc.name.clone()).or_default();
2643 if !defs.contains(&adt) {
2644 defs.push(adt);
2645 }
2646 }
2647 }
2648 TypeKind::App(fun, arg) => {
2649 visit(fun, out, seen, defs_by_name);
2650 visit(arg, out, seen, defs_by_name);
2651 }
2652 TypeKind::Fun(arg, ret) => {
2653 visit(arg, out, seen, defs_by_name);
2654 visit(ret, out, seen, defs_by_name);
2655 }
2656 TypeKind::Tuple(elems) => {
2657 for elem in elems {
2658 visit(elem, out, seen, defs_by_name);
2659 }
2660 }
2661 TypeKind::Record(fields) => {
2662 for (_name, field_ty) in fields {
2663 visit(field_ty, out, seen, defs_by_name);
2664 }
2665 }
2666 }
2667 }
2668
2669 let mut out = Vec::new();
2670 let mut seen = HashSet::new();
2671 let mut defs_by_name: BTreeMap<Symbol, Vec<Type>> = BTreeMap::new();
2672 for typ in &types {
2673 visit(typ, &mut out, &mut seen, &mut defs_by_name);
2674 }
2675
2676 let conflicts: Vec<AdtConflict> = defs_by_name
2677 .into_iter()
2678 .filter_map(|(name, definitions)| {
2679 (definitions.len() > 1).then_some(AdtConflict { name, definitions })
2680 })
2681 .collect();
2682 if !conflicts.is_empty() {
2683 return Err(CollectAdtsError { conflicts });
2684 }
2685
2686 Ok(out)
2687}