1use std::collections::{BTreeMap, HashMap, HashSet};
4use std::fmt::{self, Display, Formatter};
5use std::sync::Arc;
6
7use chrono::{DateTime, Utc};
8use rexlang_ast::expr::{
9 ClassDecl, ClassMethodSig, Decl, DeclareFnDecl, Expr, FnDecl, InstanceDecl, InstanceMethodImpl,
10 Pattern, Scope, Symbol, TypeConstraint, TypeDecl, TypeExpr, intern, sym,
11};
12use rexlang_lexer::span::Span;
13use rexlang_util::{GasMeter, OutOfGas};
14use rpds::HashTrieMapSync;
15use uuid::Uuid;
16
17use crate::prelude;
18
19#[path = "inference.rs"]
20pub mod inference;
21
22pub use inference::{infer, infer_typed, infer_typed_with_gas, infer_with_gas};
23
/// Numeric identifier of a type variable.
pub type TypeVarId = usize;
25
/// Identifies one of the types that [`Type::builtin`] can construct directly.
///
/// Each variant has a fixed source-level name ([`BuiltinTypeId::as_str`]) and
/// a fixed number of type parameters ([`BuiltinTypeId::arity`]).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BuiltinTypeId {
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F32,
    F64,
    Bool,
    String,
    Uuid,
    DateTime,
    List,
    Array,
    Dict,
    Option,
    Promise,
    Result,
}
49
50impl BuiltinTypeId {
51 pub fn as_symbol(self) -> Symbol {
52 sym(self.as_str())
53 }
54
55 pub fn as_str(self) -> &'static str {
56 match self {
57 Self::U8 => "u8",
58 Self::U16 => "u16",
59 Self::U32 => "u32",
60 Self::U64 => "u64",
61 Self::I8 => "i8",
62 Self::I16 => "i16",
63 Self::I32 => "i32",
64 Self::I64 => "i64",
65 Self::F32 => "f32",
66 Self::F64 => "f64",
67 Self::Bool => "bool",
68 Self::String => "string",
69 Self::Uuid => "uuid",
70 Self::DateTime => "datetime",
71 Self::List => "List",
72 Self::Array => "Array",
73 Self::Dict => "Dict",
74 Self::Option => "Option",
75 Self::Promise => "Promise",
76 Self::Result => "Result",
77 }
78 }
79
80 pub fn arity(self) -> usize {
81 match self {
82 Self::List | Self::Array | Self::Dict | Self::Option | Self::Promise => 1,
83 Self::Result => 2,
84 _ => 0,
85 }
86 }
87
88 pub fn from_symbol(name: &Symbol) -> Option<Self> {
89 Self::from_name(name.as_ref())
90 }
91
92 pub fn from_name(name: &str) -> Option<Self> {
93 match name {
94 "u8" => Some(Self::U8),
95 "u16" => Some(Self::U16),
96 "u32" => Some(Self::U32),
97 "u64" => Some(Self::U64),
98 "i8" => Some(Self::I8),
99 "i16" => Some(Self::I16),
100 "i32" => Some(Self::I32),
101 "i64" => Some(Self::I64),
102 "f32" => Some(Self::F32),
103 "f64" => Some(Self::F64),
104 "bool" => Some(Self::Bool),
105 "string" => Some(Self::String),
106 "uuid" => Some(Self::Uuid),
107 "datetime" => Some(Self::DateTime),
108 "List" => Some(Self::List),
109 "Array" => Some(Self::Array),
110 "Dict" => Some(Self::Dict),
111 "Option" => Some(Self::Option),
112 "Promise" => Some(Self::Promise),
113 "Result" => Some(Self::Result),
114 _ => None,
115 }
116 }
117}
118
/// A type variable: a numeric id plus an optional name.
///
/// `Display for Type` prints a named variable as `'name` and an unnamed one
/// as `t<id>`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TypeVar {
    /// Identifier used as the key in substitutions.
    pub id: TypeVarId,
    /// Optional name, used when printing the variable.
    pub name: Option<Symbol>,
}
124
125impl TypeVar {
126 pub fn new(id: TypeVarId, name: impl Into<Option<Symbol>>) -> Self {
127 Self {
128 id,
129 name: name.into(),
130 }
131 }
132}
133
/// A type constructor such as `i32`, `List`, or a user-declared type.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TypeConst {
    /// Constructor name.
    pub name: Symbol,
    /// Number of type arguments the constructor takes.
    pub arity: usize,
    /// `Some` for constructors created by `Type::builtin`, `None` for those
    /// created by `Type::user_con`.
    pub builtin_id: Option<BuiltinTypeId>,
}
140
/// A type term: a reference-counted handle around a [`TypeKind`] node, so
/// cloning a `Type` only bumps the `Arc` reference count.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Type(Arc<TypeKind>);
143
/// The shape of a [`Type`] node.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TypeKind {
    /// A type variable.
    Var(TypeVar),
    /// A type constructor.
    Con(TypeConst),
    /// Application of the first type to the second (printed as `(f arg)`).
    App(Type, Type),
    /// Function from the first type to the second (printed as `(a -> b)`).
    Fun(Type, Type),
    /// Tuple of element types.
    Tuple(Vec<Type>),
    /// Record of named fields; `Type::record` stores them sorted by name.
    Record(Vec<(Symbol, Type)>),
}
157
158impl Type {
159 pub fn new(kind: TypeKind) -> Self {
160 Type(Arc::new(kind))
161 }
162
163 pub fn con(name: impl AsRef<str>, arity: usize) -> Self {
164 if let Some(id) = BuiltinTypeId::from_name(name.as_ref())
165 && id.arity() == arity
166 {
167 return Self::builtin(id);
168 }
169 Self::user_con(name, arity)
170 }
171
172 pub fn user_con(name: impl AsRef<str>, arity: usize) -> Self {
173 Type::new(TypeKind::Con(TypeConst {
174 name: intern(name.as_ref()),
175 arity,
176 builtin_id: None,
177 }))
178 }
179
180 pub fn builtin(id: BuiltinTypeId) -> Self {
181 Type::new(TypeKind::Con(TypeConst {
182 name: id.as_symbol(),
183 arity: id.arity(),
184 builtin_id: Some(id),
185 }))
186 }
187
188 pub fn var(tv: TypeVar) -> Self {
189 Type::new(TypeKind::Var(tv))
190 }
191
192 pub fn fun(a: Type, b: Type) -> Self {
193 Type::new(TypeKind::Fun(a, b))
194 }
195
196 pub fn app(f: Type, arg: Type) -> Self {
197 Type::new(TypeKind::App(f, arg))
198 }
199
200 pub fn tuple(elems: Vec<Type>) -> Self {
201 Type::new(TypeKind::Tuple(elems))
202 }
203
204 pub fn record(mut fields: Vec<(Symbol, Type)>) -> Self {
205 fields.sort_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
208 Type::new(TypeKind::Record(fields))
209 }
210
211 pub fn list(elem: Type) -> Type {
212 Type::app(Type::builtin(BuiltinTypeId::List), elem)
213 }
214
215 pub fn array(elem: Type) -> Type {
216 Type::app(Type::builtin(BuiltinTypeId::Array), elem)
217 }
218
219 pub fn dict(elem: Type) -> Type {
220 Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
221 }
222
223 pub fn option(elem: Type) -> Type {
224 Type::app(Type::builtin(BuiltinTypeId::Option), elem)
225 }
226
227 pub fn promise(elem: Type) -> Type {
228 Type::app(Type::builtin(BuiltinTypeId::Promise), elem)
229 }
230
231 pub fn result(ok: Type, err: Type) -> Type {
232 Type::app(Type::app(Type::builtin(BuiltinTypeId::Result), err), ok)
233 }
234
235 fn apply_with_change(&self, s: &Subst) -> (Type, bool) {
236 match self.as_ref() {
237 TypeKind::Var(tv) => match s.get(&tv.id) {
238 Some(ty) => (ty.clone(), true),
239 None => (self.clone(), false),
240 },
241 TypeKind::Con(_) => (self.clone(), false),
242 TypeKind::App(l, r) => {
243 let (l_new, l_changed) = l.apply_with_change(s);
244 let (r_new, r_changed) = r.apply_with_change(s);
245 if l_changed || r_changed {
246 (Type::app(l_new, r_new), true)
247 } else {
248 (self.clone(), false)
249 }
250 }
251 TypeKind::Fun(_, _) => {
252 let mut args = Vec::new();
255 let mut changed = false;
256 let mut cur: &Type = self;
257 while let TypeKind::Fun(a, b) = cur.as_ref() {
258 let (a_new, a_changed) = a.apply_with_change(s);
259 changed |= a_changed;
260 args.push(a_new);
261 cur = b;
262 }
263 let (ret_new, ret_changed) = cur.apply_with_change(s);
264 changed |= ret_changed;
265 if !changed {
266 return (self.clone(), false);
267 }
268 let mut out = ret_new;
269 for a_new in args.into_iter().rev() {
270 out = Type::fun(a_new, out);
271 }
272 (out, true)
273 }
274 TypeKind::Tuple(ts) => {
275 let mut changed = false;
276 let mut out = Vec::with_capacity(ts.len());
277 for t in ts {
278 let (t_new, t_changed) = t.apply_with_change(s);
279 changed |= t_changed;
280 out.push(t_new);
281 }
282 if changed {
283 (Type::new(TypeKind::Tuple(out)), true)
284 } else {
285 (self.clone(), false)
286 }
287 }
288 TypeKind::Record(fields) => {
289 let mut changed = false;
290 let mut out = Vec::with_capacity(fields.len());
291 for (k, v) in fields {
292 let (v_new, v_changed) = v.apply_with_change(s);
293 changed |= v_changed;
294 out.push((k.clone(), v_new));
295 }
296 if changed {
297 (Type::new(TypeKind::Record(out)), true)
298 } else {
299 (self.clone(), false)
300 }
301 }
302 }
303 }
304}
305
306impl AsRef<TypeKind> for Type {
307 fn as_ref(&self) -> &TypeKind {
308 self.0.as_ref()
309 }
310}
311
312impl std::ops::Deref for Type {
313 type Target = TypeKind;
314
315 fn deref(&self) -> &Self::Target {
316 &self.0
317 }
318}
319
320impl Display for Type {
321 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
322 match self.as_ref() {
323 TypeKind::Var(tv) => match &tv.name {
324 Some(name) => write!(f, "'{}", name),
325 None => write!(f, "t{}", tv.id),
326 },
327 TypeKind::Con(c) => write!(f, "{}", c.name),
328 TypeKind::App(l, r) => {
329 if let TypeKind::App(head, err) = l.as_ref()
335 && matches!(
336 head.as_ref(),
337 TypeKind::Con(c)
338 if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
339 )
340 {
341 return write!(f, "(Result {} {})", r, err);
342 }
343 write!(f, "({} {})", l, r)
344 }
345 TypeKind::Fun(a, b) => write!(f, "({} -> {})", a, b),
346 TypeKind::Tuple(elems) => {
347 write!(f, "(")?;
348 for (i, t) in elems.iter().enumerate() {
349 write!(f, "{}", t)?;
350 if i + 1 < elems.len() {
351 write!(f, ", ")?;
352 }
353 }
354 write!(f, ")")
355 }
356 TypeKind::Record(fields) => {
357 write!(f, "{{")?;
358 for (i, (name, ty)) in fields.iter().enumerate() {
359 write!(f, "{}: {}", name, ty)?;
360 if i + 1 < fields.len() {
361 write!(f, ", ")?;
362 }
363 }
364 write!(f, "}}")
365 }
366 }
367 }
368}
369
/// A type-class constraint: class `class` applied to type `typ`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Predicate {
    /// Name of the class.
    pub class: Symbol,
    /// Type the class is applied to.
    pub typ: Type,
}
375
376impl Predicate {
377 pub fn new(class: impl AsRef<str>, typ: Type) -> Self {
378 Self {
379 class: intern(class.as_ref()),
380 typ,
381 }
382 }
383}
384
/// A type scheme: a type with quantified variables and class constraints.
#[derive(Clone, Debug, PartialEq)]
pub struct Scheme {
    /// Quantified variables; substitutions skip them and `ftv` excludes them.
    pub vars: Vec<TypeVar>,
    /// Class constraints attached to the type.
    pub preds: Vec<Predicate>,
    /// The underlying type.
    pub typ: Type,
}
391
392impl Scheme {
393 pub fn new(vars: Vec<TypeVar>, preds: Vec<Predicate>, typ: Type) -> Self {
394 Self { vars, preds, typ }
395 }
396}
397
/// A substitution: a persistent map from type-variable ids to types.
pub type Subst = HashTrieMapSync<TypeVarId, Type>;
399
/// Things that contain types and can therefore be substituted into and
/// queried for free type variables.
pub trait Types: Sized {
    /// Returns a copy of `self` with the substitution `s` applied.
    fn apply(&self, s: &Subst) -> Self;
    /// Returns the ids of the free type variables in `self`.
    fn ftv(&self) -> HashSet<TypeVarId>;
}
404
405impl Types for Type {
406 fn apply(&self, s: &Subst) -> Self {
407 self.apply_with_change(s).0
408 }
409
410 fn ftv(&self) -> HashSet<TypeVarId> {
411 let mut out = HashSet::new();
412 let mut stack: Vec<&Type> = vec![self];
413 while let Some(t) = stack.pop() {
414 match t.as_ref() {
415 TypeKind::Var(tv) => {
416 out.insert(tv.id);
417 }
418 TypeKind::Con(_) => {}
419 TypeKind::App(l, r) => {
420 stack.push(l);
421 stack.push(r);
422 }
423 TypeKind::Fun(a, b) => {
424 stack.push(a);
425 stack.push(b);
426 }
427 TypeKind::Tuple(ts) => {
428 for t in ts {
429 stack.push(t);
430 }
431 }
432 TypeKind::Record(fields) => {
433 for (_, ty) in fields {
434 stack.push(ty);
435 }
436 }
437 }
438 }
439 out
440 }
441}
442
443impl Types for Predicate {
444 fn apply(&self, s: &Subst) -> Self {
445 Predicate {
446 class: self.class.clone(),
447 typ: self.typ.apply(s),
448 }
449 }
450
451 fn ftv(&self) -> HashSet<TypeVarId> {
452 self.typ.ftv()
453 }
454}
455
456impl Types for Scheme {
457 fn apply(&self, s: &Subst) -> Self {
458 let mut s_pruned = Subst::new_sync();
459 for (k, v) in s.iter() {
460 if !self.vars.iter().any(|var| var.id == *k) {
461 s_pruned = s_pruned.insert(*k, v.clone());
462 }
463 }
464 Scheme::new(
465 self.vars.clone(),
466 self.preds.iter().map(|p| p.apply(&s_pruned)).collect(),
467 self.typ.apply(&s_pruned),
468 )
469 }
470
471 fn ftv(&self) -> HashSet<TypeVarId> {
472 let mut ftv = self.typ.ftv();
473 for p in &self.preds {
474 ftv.extend(p.ftv());
475 }
476 for v in &self.vars {
477 ftv.remove(&v.id);
478 }
479 ftv
480 }
481}
482
483impl<T: Types> Types for Vec<T> {
484 fn apply(&self, s: &Subst) -> Self {
485 self.iter().map(|t| t.apply(s)).collect()
486 }
487
488 fn ftv(&self) -> HashSet<TypeVarId> {
489 self.iter().flat_map(Types::ftv).collect()
490 }
491}
492
/// An expression node annotated with its type.
#[derive(Clone, Debug, PartialEq)]
pub struct TypedExpr {
    /// Type of this expression.
    pub typ: Type,
    /// Shape of the expression and its typed children.
    pub kind: TypedExprKind,
}
498
499impl TypedExpr {
500 pub fn new(typ: Type, kind: TypedExprKind) -> Self {
501 Self { typ, kind }
502 }
503
504 pub fn apply(&self, s: &Subst) -> Self {
505 match &self.kind {
506 TypedExprKind::Lam { .. } => {
507 let mut params: Vec<(Symbol, Type)> = Vec::new();
508 let mut cur = self;
509 while let TypedExprKind::Lam { param, body } = &cur.kind {
510 params.push((param.clone(), cur.typ.apply(s)));
511 cur = body.as_ref();
512 }
513 let mut out = cur.apply(s);
514 for (param, typ) in params.into_iter().rev() {
515 out = TypedExpr {
516 typ,
517 kind: TypedExprKind::Lam {
518 param,
519 body: Box::new(out),
520 },
521 };
522 }
523 return out;
524 }
525 TypedExprKind::App(..) => {
526 let mut apps: Vec<(Type, &TypedExpr)> = Vec::new();
527 let mut cur = self;
528 while let TypedExprKind::App(f, x) = &cur.kind {
529 apps.push((cur.typ.apply(s), x.as_ref()));
530 cur = f.as_ref();
531 }
532 let mut out = cur.apply(s);
533 for (typ, arg) in apps.into_iter().rev() {
534 out = TypedExpr {
535 typ,
536 kind: TypedExprKind::App(Box::new(out), Box::new(arg.apply(s))),
537 };
538 }
539 return out;
540 }
541 _ => {}
542 }
543
544 let typ = self.typ.apply(s);
545 let kind = match &self.kind {
546 TypedExprKind::Bool(v) => TypedExprKind::Bool(*v),
547 TypedExprKind::Uint(v) => TypedExprKind::Uint(*v),
548 TypedExprKind::Int(v) => TypedExprKind::Int(*v),
549 TypedExprKind::Float(v) => TypedExprKind::Float(*v),
550 TypedExprKind::String(v) => TypedExprKind::String(v.clone()),
551 TypedExprKind::Uuid(v) => TypedExprKind::Uuid(*v),
552 TypedExprKind::DateTime(v) => TypedExprKind::DateTime(*v),
553 TypedExprKind::Hole => TypedExprKind::Hole,
554 TypedExprKind::Tuple(elems) => {
555 TypedExprKind::Tuple(elems.iter().map(|e| e.apply(s)).collect())
556 }
557 TypedExprKind::List(elems) => {
558 TypedExprKind::List(elems.iter().map(|e| e.apply(s)).collect())
559 }
560 TypedExprKind::Dict(kvs) => {
561 let mut out = BTreeMap::new();
562 for (k, v) in kvs {
563 out.insert(k.clone(), v.apply(s));
564 }
565 TypedExprKind::Dict(out)
566 }
567 TypedExprKind::RecordUpdate { base, updates } => {
568 let mut out = BTreeMap::new();
569 for (k, v) in updates {
570 out.insert(k.clone(), v.apply(s));
571 }
572 TypedExprKind::RecordUpdate {
573 base: Box::new(base.apply(s)),
574 updates: out,
575 }
576 }
577 TypedExprKind::Var { name, overloads } => TypedExprKind::Var {
578 name: name.clone(),
579 overloads: overloads.iter().map(|t| t.apply(s)).collect(),
580 },
581 TypedExprKind::App(f, x) => {
582 TypedExprKind::App(Box::new(f.apply(s)), Box::new(x.apply(s)))
583 }
584 TypedExprKind::Project { expr, field } => TypedExprKind::Project {
585 expr: Box::new(expr.apply(s)),
586 field: field.clone(),
587 },
588 TypedExprKind::Lam { param, body } => TypedExprKind::Lam {
589 param: param.clone(),
590 body: Box::new(body.apply(s)),
591 },
592 TypedExprKind::Let { name, def, body } => TypedExprKind::Let {
593 name: name.clone(),
594 def: Box::new(def.apply(s)),
595 body: Box::new(body.apply(s)),
596 },
597 TypedExprKind::LetRec { bindings, body } => TypedExprKind::LetRec {
598 bindings: bindings
599 .iter()
600 .map(|(name, def)| (name.clone(), def.apply(s)))
601 .collect(),
602 body: Box::new(body.apply(s)),
603 },
604 TypedExprKind::Ite {
605 cond,
606 then_expr,
607 else_expr,
608 } => TypedExprKind::Ite {
609 cond: Box::new(cond.apply(s)),
610 then_expr: Box::new(then_expr.apply(s)),
611 else_expr: Box::new(else_expr.apply(s)),
612 },
613 TypedExprKind::Match { scrutinee, arms } => TypedExprKind::Match {
614 scrutinee: Box::new(scrutinee.apply(s)),
615 arms: arms.iter().map(|(p, e)| (p.clone(), e.apply(s))).collect(),
616 },
617 };
618 TypedExpr { typ, kind }
619 }
620}
621
/// The shape of a [`TypedExpr`]: literals, collections, variables, and the
/// compound expression forms, each holding typed sub-expressions.
#[derive(Clone, Debug, PartialEq)]
pub enum TypedExprKind {
    /// Boolean literal.
    Bool(bool),
    /// Unsigned integer literal.
    Uint(u64),
    /// Signed integer literal.
    Int(i64),
    /// Floating-point literal.
    Float(f64),
    /// String literal.
    String(String),
    /// UUID literal.
    Uuid(Uuid),
    /// UTC date-time literal.
    DateTime(DateTime<Utc>),
    /// A hole (carries no payload).
    Hole,
    /// Tuple of expressions.
    Tuple(Vec<TypedExpr>),
    /// List of expressions.
    List(Vec<TypedExpr>),
    /// Expressions keyed by symbol.
    Dict(BTreeMap<Symbol, TypedExpr>),
    /// A base expression together with replacement values for named fields.
    RecordUpdate {
        base: Box<TypedExpr>,
        updates: BTreeMap<Symbol, TypedExpr>,
    },
    /// A variable reference with the types of its overloads.
    Var {
        name: Symbol,
        overloads: Vec<Type>,
    },
    /// Application of the first expression to the second.
    App(Box<TypedExpr>, Box<TypedExpr>),
    /// Projection of `field` out of `expr`.
    Project {
        expr: Box<TypedExpr>,
        field: Symbol,
    },
    /// Single-parameter lambda.
    Lam {
        param: Symbol,
        body: Box<TypedExpr>,
    },
    /// `let name = def in body`.
    Let {
        name: Symbol,
        def: Box<TypedExpr>,
        body: Box<TypedExpr>,
    },
    /// A group of named bindings followed by a body.
    LetRec {
        bindings: Vec<(Symbol, TypedExpr)>,
        body: Box<TypedExpr>,
    },
    /// Conditional expression.
    Ite {
        cond: Box<TypedExpr>,
        then_expr: Box<TypedExpr>,
        else_expr: Box<TypedExpr>,
    },
    /// Pattern match over `scrutinee`.
    Match {
        scrutinee: Box<TypedExpr>,
        arms: Vec<(Pattern, TypedExpr)>,
    },
}
671
672pub fn compose_subst(a: Subst, b: Subst) -> Subst {
677 if subst_is_empty(&a) {
678 return b;
679 }
680 if subst_is_empty(&b) {
681 return a;
682 }
683 let mut res = Subst::new_sync();
684 for (k, v) in b.iter() {
685 res = res.insert(*k, v.apply(&a));
686 }
687 for (k, v) in a.iter() {
688 res = res.insert(*k, v.clone());
689 }
690 res
691}
692
693fn subst_is_empty(s: &Subst) -> bool {
694 s.iter().next().is_none()
695}
696
/// Errors reported by the type system.
///
/// Types are carried as pre-rendered strings. The `#[error]` attribute on
/// each variant gives its message.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum TypeError {
    /// Two types could not be unified.
    #[error("types do not unify: {0} vs {1}")]
    Unification(String, String),
    /// Binding the variable would create an infinite type.
    #[error("occurs check failed for {0} in {1}")]
    Occurs(TypeVarId, String),
    #[error("unknown class {0}")]
    UnknownClass(Symbol),
    #[error("no instance for {0} {1}")]
    NoInstance(Symbol, String),
    #[error("unknown type {0}")]
    UnknownTypeName(Symbol),
    #[error("cannot redefine reserved builtin type `{0}`")]
    ReservedTypeName(Symbol),
    #[error("duplicate value definition `{0}`")]
    DuplicateValue(Symbol),
    #[error("duplicate class definition `{0}`")]
    DuplicateClass(Symbol),
    #[error("class `{class}` must have at least one type parameter (got {got})")]
    InvalidClassArity { class: Symbol, got: usize },
    #[error("duplicate class method `{0}`")]
    DuplicateClassMethod(Symbol),
    #[error("unknown method `{method}` in instance of class `{class}`")]
    UnknownInstanceMethod { class: Symbol, method: Symbol },
    #[error("missing implementation of `{method}` for instance of class `{class}`")]
    MissingInstanceMethod { class: Symbol, method: Symbol },
    #[error(
        "instance method `{method}` requires constraint {class} {typ}, but it is not in the instance context"
    )]
    MissingInstanceConstraint {
        method: Symbol,
        class: Symbol,
        typ: String,
    },
    #[error("unbound variable {0}")]
    UnknownVar(Symbol),
    #[error("ambiguous overload for {0}")]
    AmbiguousOverload(Symbol),
    /// Quantified variables that appear in constraints but not in the type.
    #[error("ambiguous type variable(s) {vars:?} in constraints: {constraints}")]
    AmbiguousTypeVars {
        vars: Vec<TypeVarId>,
        constraints: String,
    },
    #[error(
        "kind mismatch for class `{class}`: expected {expected} type argument(s) remaining, got {got} for {typ}"
    )]
    KindMismatch {
        class: Symbol,
        expected: usize,
        got: usize,
        typ: String,
    },
    #[error("missing type class constraint(s): {constraints}")]
    MissingConstraints { constraints: String },
    #[error("unsupported expression {0}")]
    UnsupportedExpr(&'static str),
    #[error("unknown field `{field}` on {typ}")]
    UnknownField { field: Symbol, typ: String },
    #[error("field `{field}` is not definitely available on {typ}")]
    FieldNotKnown { field: Symbol, typ: String },
    #[error("non-exhaustive match for {typ}: missing {missing:?}")]
    NonExhaustiveMatch { typ: String, missing: Vec<Symbol> },
    /// Another error annotated with the source span it was raised at.
    #[error("at {span}: {error}")]
    Spanned { span: Span, error: Box<TypeError> },
    #[error("internal error: {0}")]
    Internal(String),
    /// The gas meter ran out; converted automatically from `OutOfGas`.
    #[error("{0}")]
    OutOfGas(#[from] OutOfGas),
}
766
767fn with_span(span: &Span, err: TypeError) -> TypeError {
768 match err {
769 TypeError::Spanned { .. } => err,
770 other => TypeError::Spanned {
771 span: *span,
772 error: Box::new(other),
773 },
774 }
775}
776
777fn format_constraints_referencing_vars(preds: &[Predicate], vars: &[TypeVarId]) -> String {
778 if vars.is_empty() {
779 return String::new();
780 }
781 let var_set: HashSet<TypeVarId> = vars.iter().copied().collect();
782 let mut parts = Vec::new();
783 for pred in preds {
784 let ftv = pred.ftv();
785 if ftv.iter().any(|v| var_set.contains(v)) {
786 parts.push(format!("{} {}", pred.class, pred.typ));
787 }
788 }
789 if parts.is_empty() {
790 for pred in preds {
792 parts.push(format!("{} {}", pred.class, pred.typ));
793 }
794 }
795 parts.join(", ")
796}
797
798fn reject_ambiguous_scheme(scheme: &Scheme) -> Result<(), TypeError> {
799 let quantified: HashSet<TypeVarId> = scheme.vars.iter().map(|v| v.id).collect();
803 if quantified.is_empty() {
804 return Ok(());
805 }
806
807 let typ_ftv = scheme.typ.ftv();
808 let mut vars = HashSet::new();
809 for pred in &scheme.preds {
810 let TypeKind::Var(tv) = pred.typ.as_ref() else {
811 continue;
812 };
813 if quantified.contains(&tv.id) && !typ_ftv.contains(&tv.id) {
814 vars.insert(tv.id);
815 }
816 }
817
818 if vars.is_empty() {
819 return Ok(());
820 }
821 let mut vars: Vec<TypeVarId> = vars.into_iter().collect();
822 vars.sort_unstable();
823 let constraints = format_constraints_referencing_vars(&scheme.preds, &vars);
824 Err(TypeError::AmbiguousTypeVars { vars, constraints })
825}
826
827fn scheme_compatible(existing: &Scheme, declared: &Scheme) -> bool {
828 let s = match unify(&existing.typ, &declared.typ) {
829 Ok(s) => s,
830 Err(_) => return false,
831 };
832
833 let existing_preds = existing.preds.apply(&s);
834 let declared_preds = declared.preds.apply(&s);
835
836 let mut lhs: Vec<(Symbol, String)> = existing_preds
837 .iter()
838 .map(|p| (p.class.clone(), p.typ.to_string()))
839 .collect();
840 let mut rhs: Vec<(Symbol, String)> = declared_preds
841 .iter()
842 .map(|p| (p.class.clone(), p.typ.to_string()))
843 .collect();
844 lhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
845 rhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
846 lhs == rhs
847}
848
/// Mutable unification state.
#[derive(Debug)]
struct Unifier<'g> {
    /// Bindings indexed by type-variable id; `None` (or an index past the
    /// end) means the variable is unbound.
    subs: Vec<Option<Type>>,
    /// Optional gas meter charged for inference nodes and unification steps.
    gas: Option<&'g mut GasMeter>,
    /// Upper bound on nested `with_infer_depth` calls, if any.
    max_infer_depth: Option<usize>,
    /// Current nesting depth of `with_infer_depth` calls.
    infer_depth: usize,
}
862
/// Resource limits for the type system.
#[derive(Clone, Copy, Debug)]
pub struct TypeSystemLimits {
    /// Maximum inference nesting depth; `None` disables the limit.
    pub max_infer_depth: Option<usize>,
}

impl TypeSystemLimits {
    /// Depth limit used by [`TypeSystemLimits::safe_defaults`].
    const SAFE_MAX_INFER_DEPTH: usize = 4096;

    /// Limits with every bound disabled.
    pub fn unlimited() -> Self {
        let max_infer_depth = None;
        Self { max_infer_depth }
    }

    /// Limits with a finite inference depth bound.
    pub fn safe_defaults() -> Self {
        let max_infer_depth = Some(Self::SAFE_MAX_INFER_DEPTH);
        Self { max_infer_depth }
    }
}

impl Default for TypeSystemLimits {
    /// Same as [`TypeSystemLimits::safe_defaults`].
    fn default() -> Self {
        TypeSystemLimits::safe_defaults()
    }
}
887
888fn superclass_closure(class_env: &ClassEnv, given: &[Predicate]) -> Vec<Predicate> {
889 let mut closure: Vec<Predicate> = given.to_vec();
890 let mut i = 0;
891 while i < closure.len() {
892 let p = closure[i].clone();
893 for sup in class_env.supers_of(&p.class) {
894 closure.push(Predicate::new(sup, p.typ.clone()));
895 }
896 i += 1;
897 }
898 closure
899}
900
901fn check_non_ground_predicates_declared(
902 class_env: &ClassEnv,
903 declared: &[Predicate],
904 inferred: &[Predicate],
905) -> Result<(), TypeError> {
906 let closure = superclass_closure(class_env, declared);
910 let closure_keys: HashSet<String> = closure
911 .iter()
912 .map(|p| format!("{} {}", p.class, p.typ))
913 .collect();
914 let mut missing = Vec::new();
915 for pred in inferred {
916 if pred.typ.ftv().is_empty() {
917 continue;
918 }
919 let key = format!("{} {}", pred.class, pred.typ);
920 if !closure_keys.contains(&key) {
921 missing.push(key);
922 }
923 }
924
925 missing.sort();
926 missing.dedup();
927 if missing.is_empty() {
928 return Ok(());
929 }
930 Err(TypeError::MissingConstraints {
931 constraints: missing.join(", "),
932 })
933}
934
935fn type_term_remaining_arity(ty: &Type) -> Option<usize> {
936 match ty.as_ref() {
937 TypeKind::Var(_) => None,
938 TypeKind::Con(tc) => Some(tc.arity),
939 TypeKind::App(l, _) => {
940 let a = type_term_remaining_arity(l)?;
941 Some(a.saturating_sub(1))
942 }
943 TypeKind::Fun(..) | TypeKind::Tuple(..) | TypeKind::Record(..) => Some(0),
944 }
945}
946
947fn max_head_app_arity_for_var(ty: &Type, var_id: TypeVarId) -> usize {
948 let mut max_arity = 0usize;
949 let mut stack: Vec<&Type> = vec![ty];
950 while let Some(t) = stack.pop() {
951 match t.as_ref() {
952 TypeKind::Var(_) | TypeKind::Con(_) => {}
953 TypeKind::App(l, r) => {
954 let mut head = t;
956 let mut args = 0usize;
957 while let TypeKind::App(left, _) = head.as_ref() {
958 args += 1;
959 head = left;
960 }
961 if let TypeKind::Var(tv) = head.as_ref()
962 && tv.id == var_id
963 {
964 max_arity = max_arity.max(args);
965 }
966 stack.push(l);
967 stack.push(r);
968 }
969 TypeKind::Fun(a, b) => {
970 stack.push(a);
971 stack.push(b);
972 }
973 TypeKind::Tuple(ts) => {
974 for t in ts {
975 stack.push(t);
976 }
977 }
978 TypeKind::Record(fields) => {
979 for (_, t) in fields {
980 stack.push(t);
981 }
982 }
983 }
984 }
985 max_arity
986}
987
988impl<'g> Unifier<'g> {
989 fn new(max_infer_depth: Option<usize>) -> Self {
990 Self {
991 subs: Vec::new(),
992 gas: None,
993 max_infer_depth,
994 infer_depth: 0,
995 }
996 }
997
998 fn with_gas(gas: &'g mut GasMeter, max_infer_depth: Option<usize>) -> Self {
999 Self {
1000 subs: Vec::new(),
1001 gas: Some(gas),
1002 max_infer_depth,
1003 infer_depth: 0,
1004 }
1005 }
1006
1007 fn with_infer_depth<T>(
1008 &mut self,
1009 span: Span,
1010 f: impl FnOnce(&mut Self) -> Result<T, TypeError>,
1011 ) -> Result<T, TypeError> {
1012 if let Some(max) = self.max_infer_depth
1013 && self.infer_depth >= max
1014 {
1015 return Err(TypeError::Spanned {
1016 span,
1017 error: Box::new(TypeError::Internal(format!(
1018 "maximum inference depth exceeded (max {max})"
1019 ))),
1020 });
1021 }
1022 self.infer_depth += 1;
1023 let res = f(self);
1024 self.infer_depth = self.infer_depth.saturating_sub(1);
1025 res
1026 }
1027
1028 fn charge_infer_node(&mut self) -> Result<(), TypeError> {
1029 let Some(gas) = self.gas.as_mut() else {
1030 return Ok(());
1031 };
1032 let cost = gas.costs.infer_node;
1033 gas.charge(cost)?;
1034 Ok(())
1035 }
1036
1037 fn charge_unify_step(&mut self) -> Result<(), TypeError> {
1038 let Some(gas) = self.gas.as_mut() else {
1039 return Ok(());
1040 };
1041 let cost = gas.costs.unify_step;
1042 gas.charge(cost)?;
1043 Ok(())
1044 }
1045
1046 fn bind_var(&mut self, id: TypeVarId, ty: Type) {
1047 if id >= self.subs.len() {
1048 self.subs.resize(id + 1, None);
1049 }
1050 self.subs[id] = Some(ty);
1051 }
1052
1053 fn prune(&mut self, ty: &Type) -> Type {
1054 match ty.as_ref() {
1055 TypeKind::Var(tv) => {
1056 let bound = self.subs.get(tv.id).and_then(|t| t.clone());
1057 match bound {
1058 Some(bound) => {
1059 let pruned = self.prune(&bound);
1060 self.bind_var(tv.id, pruned.clone());
1061 pruned
1062 }
1063 None => ty.clone(),
1064 }
1065 }
1066 TypeKind::Con(_) => ty.clone(),
1067 TypeKind::App(l, r) => {
1068 let l = self.prune(l);
1069 let r = self.prune(r);
1070 Type::app(l, r)
1071 }
1072 TypeKind::Fun(a, b) => {
1073 let a = self.prune(a);
1074 let b = self.prune(b);
1075 Type::fun(a, b)
1076 }
1077 TypeKind::Tuple(ts) => {
1078 Type::new(TypeKind::Tuple(ts.iter().map(|t| self.prune(t)).collect()))
1079 }
1080 TypeKind::Record(fields) => Type::new(TypeKind::Record(
1081 fields
1082 .iter()
1083 .map(|(name, ty)| (name.clone(), self.prune(ty)))
1084 .collect(),
1085 )),
1086 }
1087 }
1088
1089 fn apply_type(&mut self, ty: &Type) -> Type {
1090 self.prune(ty)
1091 }
1092
1093 fn occurs(&mut self, id: TypeVarId, ty: &Type) -> bool {
1094 match self.prune(ty).as_ref() {
1095 TypeKind::Var(tv) => tv.id == id,
1096 TypeKind::Con(_) => false,
1097 TypeKind::App(l, r) => self.occurs(id, l) || self.occurs(id, r),
1098 TypeKind::Fun(a, b) => self.occurs(id, a) || self.occurs(id, b),
1099 TypeKind::Tuple(ts) => ts.iter().any(|t| self.occurs(id, t)),
1100 TypeKind::Record(fields) => fields.iter().any(|(_, ty)| self.occurs(id, ty)),
1101 }
1102 }
1103
1104 fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
1105 self.charge_unify_step()?;
1106 let t1 = self.prune(t1);
1107 let t2 = self.prune(t2);
1108 match (t1.as_ref(), t2.as_ref()) {
1109 (TypeKind::Var(a), TypeKind::Var(b)) if a.id == b.id => Ok(()),
1110 (TypeKind::Var(tv), other) | (other, TypeKind::Var(tv)) => {
1111 if self.occurs(tv.id, &Type::new(other.clone())) {
1112 Err(TypeError::Occurs(
1113 tv.id,
1114 Type::new(other.clone()).to_string(),
1115 ))
1116 } else {
1117 self.bind_var(tv.id, Type::new(other.clone()));
1118 Ok(())
1119 }
1120 }
1121 (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(()),
1122 (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1123 self.unify(l1, l2)?;
1124 self.unify(r1, r2)
1125 }
1126 (TypeKind::Fun(a1, b1), TypeKind::Fun(a2, b2)) => {
1127 self.unify(a1, a2)?;
1128 self.unify(b1, b2)
1129 }
1130 (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1131 if ts1.len() != ts2.len() {
1132 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1133 }
1134 for (a, b) in ts1.iter().zip(ts2.iter()) {
1135 self.unify(a, b)?;
1136 }
1137 Ok(())
1138 }
1139 (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1140 if f1.len() != f2.len() {
1141 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1142 }
1143 for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1144 if n1 != n2 {
1145 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1146 }
1147 self.unify(t1, t2)?;
1148 }
1149 Ok(())
1150 }
1151 (TypeKind::Record(fields), TypeKind::App(head, arg))
1152 | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1153 TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1154 let elem_ty = record_elem_type_unifier(fields, self)?;
1155 self.unify(arg, &elem_ty)
1156 }
1157 TypeKind::Var(tv) => {
1158 self.unify(
1159 &Type::new(TypeKind::Var(tv.clone())),
1160 &Type::builtin(BuiltinTypeId::Dict),
1161 )?;
1162 let elem_ty = record_elem_type_unifier(fields, self)?;
1163 self.unify(arg, &elem_ty)
1164 }
1165 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1166 },
1167 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1168 }
1169 }
1170
1171 fn into_subst(mut self) -> Subst {
1172 let mut out = Subst::new_sync();
1173 for id in 0..self.subs.len() {
1174 if let Some(ty) = self.subs[id].clone() {
1175 let pruned = self.prune(&ty);
1176 out = out.insert(id, pruned);
1177 }
1178 }
1179 out
1180 }
1181}
1182
1183fn record_elem_type_unifier(
1184 fields: &[(Symbol, Type)],
1185 unifier: &mut Unifier<'_>,
1186) -> Result<Type, TypeError> {
1187 let mut iter = fields.iter();
1188 let first = match iter.next() {
1189 Some((_, ty)) => ty.clone(),
1190 None => return Err(TypeError::UnsupportedExpr("empty record")),
1191 };
1192 for (_, ty) in iter {
1193 unifier.unify(&first, ty)?;
1194 }
1195 Ok(unifier.apply_type(&first))
1196}
1197
1198fn bind(tv: &TypeVar, t: &Type) -> Result<Subst, TypeError> {
1199 if let TypeKind::Var(var) = t.as_ref()
1200 && var.id == tv.id
1201 {
1202 return Ok(Subst::new_sync());
1203 }
1204 if t.ftv().contains(&tv.id) {
1205 Err(TypeError::Occurs(tv.id, t.to_string()))
1206 } else {
1207 Ok(Subst::new_sync().insert(tv.id, t.clone()))
1208 }
1209}
1210
1211fn record_elem_type(fields: &[(Symbol, Type)]) -> Result<(Subst, Type), TypeError> {
1212 let mut iter = fields.iter();
1213 let first = match iter.next() {
1214 Some((_, ty)) => ty.clone(),
1215 None => return Err(TypeError::UnsupportedExpr("empty record")),
1216 };
1217 let mut subst = Subst::new_sync();
1218 let mut current = first;
1219 for (_, ty) in iter {
1220 let s_next = unify(¤t.apply(&subst), &ty.apply(&subst))?;
1221 subst = compose_subst(s_next, subst);
1222 current = current.apply(&subst);
1223 }
1224 Ok((subst.clone(), current.apply(&subst)))
1225}
1226
1227pub fn unify(t1: &Type, t2: &Type) -> Result<Subst, TypeError> {
1234 match (t1.as_ref(), t2.as_ref()) {
1235 (TypeKind::Fun(l1, r1), TypeKind::Fun(l2, r2)) => {
1236 let s1 = unify(l1, l2)?;
1237 let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1238 Ok(compose_subst(s2, s1))
1239 }
1240 (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1241 if f1.len() != f2.len() {
1242 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1243 }
1244 let mut subst = Subst::new_sync();
1245 for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1246 if n1 != n2 {
1247 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1248 }
1249 let s_next = unify(&t1.apply(&subst), &t2.apply(&subst))?;
1250 subst = compose_subst(s_next, subst);
1251 }
1252 Ok(subst)
1253 }
1254 (TypeKind::Record(fields), TypeKind::App(head, arg))
1255 | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1256 TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1257 let (s_fields, elem_ty) = record_elem_type(fields)?;
1258 let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1259 Ok(compose_subst(s_arg, s_fields))
1260 }
1261 TypeKind::Var(tv) => {
1262 let s_head = bind(tv, &Type::builtin(BuiltinTypeId::Dict))?;
1263 let arg = arg.apply(&s_head);
1264 let (s_fields, elem_ty) = record_elem_type(fields)?;
1265 let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1266 Ok(compose_subst(s_arg, compose_subst(s_fields, s_head)))
1267 }
1268 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1269 },
1270 (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1271 let s1 = unify(l1, l2)?;
1272 let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1273 Ok(compose_subst(s2, s1))
1274 }
1275 (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1276 if ts1.len() != ts2.len() {
1277 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1278 }
1279 let mut s = Subst::new_sync();
1280 for (a, b) in ts1.iter().zip(ts2.iter()) {
1281 let s_next = unify(&a.apply(&s), &b.apply(&s))?;
1282 s = compose_subst(s_next, s);
1283 }
1284 Ok(s)
1285 }
1286 (TypeKind::Var(tv), t) | (t, TypeKind::Var(tv)) => bind(tv, &Type::new(t.clone())),
1287 (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(Subst::new_sync()),
1288 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1289 }
1290}
1291
1292#[derive(Default, Debug, Clone)]
1293pub struct TypeEnv {
1294 pub values: HashTrieMapSync<Symbol, Vec<Scheme>>,
1295}
1296
1297impl TypeEnv {
1298 pub fn new() -> Self {
1299 Self {
1300 values: HashTrieMapSync::new_sync(),
1301 }
1302 }
1303
1304 pub fn extend(&mut self, name: Symbol, scheme: Scheme) {
1305 self.values = self.values.insert(name, vec![scheme]);
1306 }
1307
1308 pub fn extend_overload(&mut self, name: Symbol, scheme: Scheme) {
1309 let mut schemes = self.values.get(&name).cloned().unwrap_or_default();
1310 schemes.push(scheme);
1311 self.values = self.values.insert(name, schemes);
1312 }
1313
1314 pub fn remove(&mut self, name: &Symbol) {
1315 self.values = self.values.remove(name);
1316 }
1317
1318 pub fn lookup(&self, name: &Symbol) -> Option<&[Scheme]> {
1319 self.values.get(name).map(|schemes| schemes.as_slice())
1320 }
1321}
1322
1323impl Types for TypeEnv {
1324 fn apply(&self, s: &Subst) -> Self {
1325 let mut values = HashTrieMapSync::new_sync();
1326 for (k, v) in self.values.iter() {
1327 let updated = v
1328 .iter()
1329 .map(|scheme| {
1330 if scheme.vars.is_empty() && !subst_is_empty(s) {
1333 scheme.apply(s)
1334 } else {
1335 scheme.clone()
1336 }
1337 })
1338 .collect();
1339 values = values.insert(k.clone(), updated);
1340 }
1341 TypeEnv { values }
1342 }
1343
1344 fn ftv(&self) -> HashSet<TypeVarId> {
1345 self.values
1346 .iter()
1347 .flat_map(|(_, schemes)| schemes.iter().flat_map(Types::ftv))
1348 .collect()
1349 }
1350}
1351
1352#[derive(Default, Debug, Clone)]
1353pub struct TypeVarSupply {
1354 counter: TypeVarId,
1355}
1356
1357impl TypeVarSupply {
1358 pub fn new() -> Self {
1359 Self { counter: 0 }
1360 }
1361
1362 pub fn fresh(&mut self, name_hint: impl Into<Option<Symbol>>) -> TypeVar {
1363 let tv = TypeVar::new(self.counter, name_hint.into());
1364 self.counter += 1;
1365 tv
1366 }
1367}
1368
1369fn is_integral_literal_expr(expr: &Expr) -> bool {
1370 matches!(expr, Expr::Int(..) | Expr::Uint(..))
1371}
1372
1373pub fn generalize(env: &TypeEnv, preds: Vec<Predicate>, typ: Type) -> Scheme {
1376 let mut vars: Vec<TypeVar> = typ
1377 .ftv()
1378 .union(&preds.ftv())
1379 .copied()
1380 .collect::<HashSet<_>>()
1381 .difference(&env.ftv())
1382 .cloned()
1383 .map(|id| TypeVar::new(id, None))
1384 .collect();
1385 vars.sort_by_key(|v| v.id);
1386 Scheme::new(vars, preds, typ)
1387}
1388
1389pub fn instantiate(scheme: &Scheme, supply: &mut TypeVarSupply) -> (Vec<Predicate>, Type) {
1390 let mut subst = Subst::new_sync();
1393 for v in &scheme.vars {
1394 subst = subst.insert(v.id, Type::var(supply.fresh(v.name.clone())));
1395 }
1396 (scheme.preds.apply(&subst), scheme.typ.apply(&subst))
1397}
1398
/// A type parameter of an algebraic data type declaration.
#[derive(Clone, Debug)]
pub struct AdtParam {
    /// Parameter name as written in the declaration.
    pub name: Symbol,
    /// Type variable standing for this parameter inside the ADT's types.
    pub var: TypeVar,
}
1405
/// One constructor (variant) of an algebraic data type.
#[derive(Clone, Debug)]
pub struct AdtVariant {
    /// Constructor name.
    pub name: Symbol,
    /// Types of the constructor's arguments, in order.
    pub args: Vec<Type>,
}
1412
1413#[derive(Clone, Debug)]
1419pub struct AdtDecl {
1420 pub name: Symbol,
1421 pub params: Vec<AdtParam>,
1422 pub variants: Vec<AdtVariant>,
1423}
1424
1425impl AdtDecl {
1426 pub fn new(name: &Symbol, param_names: &[Symbol], supply: &mut TypeVarSupply) -> Self {
1427 let params = param_names
1428 .iter()
1429 .map(|p| AdtParam {
1430 name: p.clone(),
1431 var: supply.fresh(Some(p.clone())),
1432 })
1433 .collect();
1434 Self {
1435 name: name.clone(),
1436 params,
1437 variants: Vec::new(),
1438 }
1439 }
1440
1441 pub fn param_type(&self, name: &Symbol) -> Option<Type> {
1442 self.params
1443 .iter()
1444 .find(|p| &p.name == name)
1445 .map(|p| Type::var(p.var.clone()))
1446 }
1447
1448 pub fn add_variant(&mut self, name: Symbol, args: Vec<Type>) {
1449 self.variants.push(AdtVariant { name, args });
1450 }
1451
1452 pub fn result_type(&self) -> Type {
1453 let mut ty = Type::con(&self.name, self.params.len());
1454 for param in &self.params {
1455 ty = Type::app(ty, Type::var(param.var.clone()));
1456 }
1457 ty
1458 }
1459
1460 pub fn constructor_schemes(&self) -> Vec<(Symbol, Scheme)> {
1463 let result_ty = self.result_type();
1464 let vars: Vec<TypeVar> = self.params.iter().map(|p| p.var.clone()).collect();
1465 let mut out = Vec::new();
1466 for variant in &self.variants {
1467 let mut typ = result_ty.clone();
1468 for arg in variant.args.iter().rev() {
1469 typ = Type::fun(arg.clone(), typ);
1470 }
1471 out.push((variant.name.clone(), Scheme::new(vars.clone(), vec![], typ)));
1472 }
1473 out
1474 }
1475}
1476
1477#[derive(Clone, Debug)]
1478pub struct Class {
1479 pub supers: Vec<Symbol>,
1480}
1481
1482impl Class {
1483 pub fn new(supers: Vec<Symbol>) -> Self {
1484 Self { supers }
1485 }
1486}
1487
1488#[derive(Clone, Debug)]
1489pub struct Instance {
1490 pub context: Vec<Predicate>,
1491 pub head: Predicate,
1492}
1493
1494impl Instance {
1495 pub fn new(context: Vec<Predicate>, head: Predicate) -> Self {
1496 Self { context, head }
1497 }
1498}
1499
1500#[derive(Default, Debug, Clone)]
1501pub struct ClassEnv {
1502 pub classes: HashMap<Symbol, Class>,
1503 pub instances: HashMap<Symbol, Vec<Instance>>,
1504}
1505
1506impl ClassEnv {
1507 pub fn new() -> Self {
1508 Self {
1509 classes: HashMap::new(),
1510 instances: HashMap::new(),
1511 }
1512 }
1513
1514 pub fn add_class(&mut self, name: Symbol, supers: Vec<Symbol>) {
1515 self.classes.insert(name, Class::new(supers));
1516 }
1517
1518 pub fn add_instance(&mut self, class: Symbol, inst: Instance) {
1519 self.instances.entry(class).or_default().push(inst);
1520 }
1521
1522 pub fn supers_of(&self, class: &Symbol) -> Vec<Symbol> {
1523 self.classes
1524 .get(class)
1525 .map(|c| c.supers.clone())
1526 .unwrap_or_default()
1527 }
1528}
1529
1530pub fn entails(
1531 class_env: &ClassEnv,
1532 given: &[Predicate],
1533 pred: &Predicate,
1534) -> Result<bool, TypeError> {
1535 let mut closure: Vec<Predicate> = given.to_vec();
1537 let mut i = 0;
1538 while i < closure.len() {
1539 let p = closure[i].clone();
1540 for sup in class_env.supers_of(&p.class) {
1541 closure.push(Predicate::new(sup, p.typ.clone()));
1542 }
1543 i += 1;
1544 }
1545
1546 if closure
1547 .iter()
1548 .any(|p| p.class == pred.class && p.typ == pred.typ)
1549 {
1550 return Ok(true);
1551 }
1552
1553 if !class_env.classes.contains_key(&pred.class) {
1554 return Err(TypeError::UnknownClass(pred.class.clone()));
1555 }
1556
1557 if let Some(instances) = class_env.instances.get(&pred.class) {
1558 for inst in instances {
1559 if let Ok(s) = unify(&inst.head.typ, &pred.typ) {
1560 let ctx = inst.context.apply(&s);
1561 if ctx
1562 .iter()
1563 .all(|c| entails(class_env, &closure, c).unwrap_or(false))
1564 {
1565 return Ok(true);
1566 }
1567 }
1568 }
1569 }
1570 Ok(false)
1571}
1572
/// Central registry of everything the type checker knows: value bindings,
/// classes and instances, ADTs, and the supply of fresh type variables.
#[derive(Default, Debug, Clone)]
pub struct TypeSystem {
    /// Value environment (name -> type schemes).
    pub env: TypeEnv,
    /// Class environment (classes, superclasses, instances).
    pub classes: ClassEnv,
    /// Registered algebraic data types, keyed by type name.
    pub adts: HashMap<Symbol, AdtDecl>,
    /// Per-class details recorded by `register_class_decl`, keyed by class name.
    pub class_info: HashMap<Symbol, ClassInfo>,
    /// Class methods keyed by method name, with their owning class and scheme.
    pub class_methods: HashMap<Symbol, ClassMethodInfo>,
    /// Names introduced by `inject_declare_fn_decl` that have not yet been
    /// given a definition; `register_fn_decls`, `add_value` and
    /// `add_overload` remove a name from this set when they bind it.
    pub declared_values: HashSet<Symbol>,
    /// Supply of fresh type variables shared by all registrations.
    pub supply: TypeVarSupply,
    /// Limits configured through `set_limits`.
    /// NOTE(review): where these limits are enforced is not visible here.
    limits: TypeSystemLimits,
}
1588
/// Details of a declared class, as recorded by
/// `TypeSystem::register_class_decl`.
#[derive(Clone, Debug)]
pub struct ClassInfo {
    /// Class name.
    pub name: Symbol,
    /// Names of the class's type parameters, in declaration order.
    pub params: Vec<Symbol>,
    /// Names of the direct superclasses.
    pub supers: Vec<Symbol>,
    /// Method schemes keyed by method name; each scheme carries the class
    /// predicate over the class parameters.
    pub methods: BTreeMap<Symbol, Scheme>,
}
1605
/// Reverse index entry for a class method: which class owns it and what its
/// scheme is.
#[derive(Clone, Debug)]
pub struct ClassMethodInfo {
    /// Name of the class that declares the method.
    pub class: Symbol,
    /// The method's type scheme (including the class predicate).
    pub scheme: Scheme,
}
1611
/// An instance declaration after its head and context have been converted
/// to types and its method set has been validated against the class.
#[derive(Clone, Debug)]
pub struct PreparedInstanceDecl {
    /// Source span of the instance declaration.
    pub span: Span,
    /// Name of the class being instantiated.
    pub class: Symbol,
    /// The instance head type.
    pub head: Type,
    /// Predicates the instance requires (its context).
    pub context: Vec<Predicate>,
}
1619
1620impl TypeSystem {
1621 pub fn new() -> Self {
1622 Self {
1623 env: TypeEnv::new(),
1624 classes: ClassEnv::new(),
1625 adts: HashMap::new(),
1626 class_info: HashMap::new(),
1627 class_methods: HashMap::new(),
1628 declared_values: HashSet::new(),
1629 supply: TypeVarSupply::new(),
1630 limits: TypeSystemLimits::default(),
1631 }
1632 }
1633
1634 pub fn fresh_type_var(&mut self, name: Option<Symbol>) -> TypeVar {
1635 self.supply.fresh(name)
1636 }
1637
1638 pub fn set_limits(&mut self, limits: TypeSystemLimits) {
1639 self.limits = limits;
1640 }
1641
1642 pub fn new_with_prelude() -> Result<Self, TypeError> {
1643 let mut ts = TypeSystem::new();
1644 prelude::build_prelude(&mut ts)?;
1645 Ok(ts)
1646 }
1647
1648 fn register_decl(&mut self, decl: &Decl) -> Result<(), TypeError> {
1649 match decl {
1650 Decl::Type(ty) => self.register_type_decl(ty),
1651 Decl::Class(class_decl) => self.register_class_decl(class_decl),
1652 Decl::Instance(inst_decl) => {
1653 let _ = self.register_instance_decl(inst_decl)?;
1654 Ok(())
1655 }
1656 Decl::Fn(fd) => self.register_fn_decls(std::slice::from_ref(fd)),
1657 Decl::DeclareFn(fd) => self.inject_declare_fn_decl(fd),
1658 Decl::Import(..) => Ok(()),
1659 }
1660 }
1661
1662 pub fn register_decls(&mut self, decls: &[Decl]) -> Result<(), TypeError> {
1663 let mut pending_fns: Vec<FnDecl> = Vec::new();
1664 for decl in decls {
1665 if let Decl::Fn(fd) = decl {
1666 pending_fns.push(fd.clone());
1667 continue;
1668 }
1669
1670 if !pending_fns.is_empty() {
1671 self.register_fn_decls(&pending_fns)?;
1672 pending_fns.clear();
1673 }
1674
1675 self.register_decl(decl)?;
1676 }
1677 if !pending_fns.is_empty() {
1678 self.register_fn_decls(&pending_fns)?;
1679 }
1680 Ok(())
1681 }
1682
1683 pub fn add_value(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1684 let name = sym(name.as_ref());
1685 self.declared_values.remove(&name);
1686 self.env.extend(name, scheme);
1687 }
1688
1689 pub fn add_overload(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1690 let name = sym(name.as_ref());
1691 self.declared_values.remove(&name);
1692 self.env.extend_overload(name, scheme);
1693 }
1694
1695 pub fn register_instance(&mut self, class: impl AsRef<str>, inst: Instance) {
1696 self.classes.add_instance(sym(class.as_ref()), inst);
1697 }
1698
1699 pub fn register_class_decl(&mut self, decl: &ClassDecl) -> Result<(), TypeError> {
1700 let span = decl.span;
1701 (|| {
1702 if self.class_info.contains_key(&decl.name)
1706 || self.classes.classes.contains_key(&decl.name)
1707 {
1708 return Err(TypeError::DuplicateClass(decl.name.clone()));
1709 }
1710 if decl.params.is_empty() {
1711 return Err(TypeError::InvalidClassArity {
1712 class: decl.name.clone(),
1713 got: decl.params.len(),
1714 });
1715 }
1716 let params = decl.params.clone();
1717
1718 let mut supers = Vec::with_capacity(decl.supers.len());
1724 if !decl.supers.is_empty() && params.len() != 1 {
1725 return Err(TypeError::UnsupportedExpr(
1726 "multi-parameter classes cannot declare superclasses yet",
1727 ));
1728 }
1729 for sup in &decl.supers {
1730 let mut vars = HashMap::new();
1731 let param = params[0].clone();
1732 let param_tv = self.supply.fresh(Some(param.clone()));
1733 vars.insert(param, param_tv.clone());
1734 let sup_ty = type_from_annotation_expr_vars(
1735 &self.adts,
1736 &sup.typ,
1737 &mut vars,
1738 &mut self.supply,
1739 )?;
1740 if sup_ty != Type::var(param_tv) {
1741 return Err(TypeError::UnsupportedExpr(
1742 "superclass constraints must be of the form `<= C a`",
1743 ));
1744 }
1745 supers.push(sup.class.to_dotted_symbol());
1746 }
1747
1748 self.classes.add_class(decl.name.clone(), supers.clone());
1749
1750 let mut methods = BTreeMap::new();
1751 for ClassMethodSig { name, typ } in &decl.methods {
1752 if self.env.lookup(name).is_some() || self.class_methods.contains_key(name) {
1753 return Err(TypeError::DuplicateClassMethod(name.clone()));
1754 }
1755
1756 let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1757 let mut param_tvs: Vec<TypeVar> = Vec::with_capacity(params.len());
1758 for param in ¶ms {
1759 let tv = self.supply.fresh(Some(param.clone()));
1760 vars.insert(param.clone(), tv.clone());
1761 param_tvs.push(tv);
1762 }
1763
1764 let ty =
1765 type_from_annotation_expr_vars(&self.adts, typ, &mut vars, &mut self.supply)?;
1766
1767 let mut scheme_vars: Vec<TypeVar> = vars.values().cloned().collect();
1768 scheme_vars.sort_by_key(|tv| tv.id);
1769 scheme_vars.dedup_by_key(|tv| tv.id);
1770
1771 let class_pred = Predicate {
1772 class: decl.name.clone(),
1773 typ: if param_tvs.len() == 1 {
1774 Type::var(param_tvs[0].clone())
1775 } else {
1776 Type::tuple(param_tvs.into_iter().map(Type::var).collect())
1777 },
1778 };
1779 let scheme = Scheme::new(scheme_vars, vec![class_pred], ty);
1780
1781 self.env.extend(name.clone(), scheme.clone());
1782 self.class_methods.insert(
1783 name.clone(),
1784 ClassMethodInfo {
1785 class: decl.name.clone(),
1786 scheme: scheme.clone(),
1787 },
1788 );
1789 methods.insert(name.clone(), scheme);
1790 }
1791
1792 self.class_info.insert(
1793 decl.name.clone(),
1794 ClassInfo {
1795 name: decl.name.clone(),
1796 params,
1797 supers,
1798 methods,
1799 },
1800 );
1801 Ok(())
1802 })()
1803 .map_err(|err| with_span(&span, err))
1804 }
1805
1806 pub fn register_instance_decl(
1807 &mut self,
1808 decl: &InstanceDecl,
1809 ) -> Result<PreparedInstanceDecl, TypeError> {
1810 let span = decl.span;
1811 (|| {
1812 let class = decl.class.clone();
1813 if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
1814 return Err(TypeError::UnknownClass(class));
1815 }
1816
1817 let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1818 let head = type_from_annotation_expr_vars(
1819 &self.adts,
1820 &decl.head,
1821 &mut vars,
1822 &mut self.supply,
1823 )?;
1824 let context = predicates_from_constraints(
1825 &self.adts,
1826 &decl.context,
1827 &mut vars,
1828 &mut self.supply,
1829 )?;
1830
1831 let inst = Instance::new(
1832 context.clone(),
1833 Predicate {
1834 class: decl.class.clone(),
1835 typ: head.clone(),
1836 },
1837 );
1838
1839 if let Some(info) = self.class_info.get(&decl.class) {
1841 for method in &decl.methods {
1842 if !info.methods.contains_key(&method.name) {
1843 return Err(TypeError::UnknownInstanceMethod {
1844 class: decl.class.clone(),
1845 method: method.name.clone(),
1846 });
1847 }
1848 }
1849 for method_name in info.methods.keys() {
1850 if !decl.methods.iter().any(|m| &m.name == method_name) {
1851 return Err(TypeError::MissingInstanceMethod {
1852 class: decl.class.clone(),
1853 method: method_name.clone(),
1854 });
1855 }
1856 }
1857 }
1858
1859 self.classes.add_instance(decl.class.clone(), inst);
1860 Ok(PreparedInstanceDecl {
1861 span,
1862 class: decl.class.clone(),
1863 head,
1864 context,
1865 })
1866 })()
1867 .map_err(|err| with_span(&span, err))
1868 }
1869
1870 pub fn prepare_instance_decl(
1871 &mut self,
1872 decl: &InstanceDecl,
1873 ) -> Result<PreparedInstanceDecl, TypeError> {
1874 let span = decl.span;
1875 (|| {
1876 let class = decl.class.clone();
1877 if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
1878 return Err(TypeError::UnknownClass(class));
1879 }
1880
1881 let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1882 let head = type_from_annotation_expr_vars(
1883 &self.adts,
1884 &decl.head,
1885 &mut vars,
1886 &mut self.supply,
1887 )?;
1888 let context = predicates_from_constraints(
1889 &self.adts,
1890 &decl.context,
1891 &mut vars,
1892 &mut self.supply,
1893 )?;
1894
1895 if let Some(info) = self.class_info.get(&decl.class) {
1897 for method in &decl.methods {
1898 if !info.methods.contains_key(&method.name) {
1899 return Err(TypeError::UnknownInstanceMethod {
1900 class: decl.class.clone(),
1901 method: method.name.clone(),
1902 });
1903 }
1904 }
1905 for method_name in info.methods.keys() {
1906 if !decl.methods.iter().any(|m| &m.name == method_name) {
1907 return Err(TypeError::MissingInstanceMethod {
1908 class: decl.class.clone(),
1909 method: method_name.clone(),
1910 });
1911 }
1912 }
1913 }
1914
1915 Ok(PreparedInstanceDecl {
1916 span,
1917 class: decl.class.clone(),
1918 head,
1919 context,
1920 })
1921 })()
1922 .map_err(|err| with_span(&span, err))
1923 }
1924
    /// Registers a group of function declarations and type-checks their
    /// bodies against their annotations.
    ///
    /// The group is processed in three passes:
    ///
    /// 1. Each declaration's signature and constraints are converted into a
    ///    scheme; duplicate names and kind mismatches in declared predicates
    ///    are rejected.
    /// 2. All schemes are bound in the value environment, so every body in
    ///    the group can refer to every function of the group.
    /// 3. Each body is wrapped into nested lambdas, inferred, unified with
    ///    the declared type, and its predicates are checked against the
    ///    declared ones.
    ///
    /// If any step fails, `env` and `declared_values` are restored to their
    /// state from before the call.
    pub fn register_fn_decls(&mut self, decls: &[FnDecl]) -> Result<(), TypeError> {
        if decls.is_empty() {
            return Ok(());
        }

        // Snapshots for rollback on failure.
        let saved_env = self.env.clone();
        let saved_declared = self.declared_values.clone();

        let result: Result<(), TypeError> = (|| {
            // Everything pass 3 needs to know about one declaration.
            #[derive(Clone)]
            struct FnInfo {
                decl: FnDecl,
                expected: Type,
                declared_preds: Vec<Predicate>,
                scheme: Scheme,
                ann_vars: HashMap<Symbol, TypeVar>,
            }

            let mut infos: Vec<FnInfo> = Vec::with_capacity(decls.len());
            let mut seen_names = HashSet::new();

            // Pass 1: build a scheme for every declaration.
            for decl in decls {
                let span = decl.span;
                let info = (|| {
                    let name = &decl.name.name;
                    // The same name twice within one group is an error.
                    if !seen_names.insert(name.clone()) {
                        return Err(TypeError::DuplicateValue(name.clone()));
                    }

                    // An existing binding may only be replaced if it came
                    // from a forward declaration.
                    if self.env.lookup(name).is_some() {
                        if self.declared_values.remove(name) {
                            self.env.remove(name);
                        } else {
                            return Err(TypeError::DuplicateValue(name.clone()));
                        }
                    }

                    // Assemble `p1 -> p2 -> ... -> ret` from the parameter
                    // annotations, innermost arrow first.
                    let mut sig = decl.ret.clone();
                    for (_, ann) in decl.params.iter().rev() {
                        let span = Span::from_begin_end(ann.span().begin, sig.span().end);
                        sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
                    }

                    let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
                    let expected = type_from_annotation_expr_vars(
                        &self.adts,
                        &sig,
                        &mut ann_vars,
                        &mut self.supply,
                    )?;
                    let declared_preds = predicates_from_constraints(
                        &self.adts,
                        &decl.constraints,
                        &mut ann_vars,
                        &mut self.supply,
                    )?;

                    // For each annotation variable, the arity computed by
                    // `max_head_app_arity_for_var` over the signature.
                    let var_arities: HashMap<TypeVarId, usize> = ann_vars
                        .values()
                        .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
                        .collect();
                    for pred in &declared_preds {
                        // Called only for its error (e.g. an unknown class);
                        // the boolean answer is ignored.
                        let _ = entails(&self.classes, &[], pred)?;
                        let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
                        else {
                            continue;
                        };
                        // Single-parameter classes take the predicate type
                        // directly; multi-parameter ones take a tuple.
                        let args: Vec<Type> = if expected_arities.len() == 1 {
                            vec![pred.typ.clone()]
                        } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
                            if parts.len() != expected_arities.len() {
                                continue;
                            }
                            parts.clone()
                        } else {
                            continue;
                        };

                        for (arg, expected_arity) in
                            args.iter().zip(expected_arities.iter().copied())
                        {
                            // Prefer the arity of the type term itself; for a
                            // bare variable fall back to its signature usage.
                            let got =
                                type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
                                    TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
                                    _ => None,
                                });
                            let Some(got) = got else {
                                continue;
                            };
                            if got != expected_arity {
                                return Err(TypeError::KindMismatch {
                                    class: pred.class.clone(),
                                    expected: expected_arity,
                                    got,
                                    typ: arg.to_string(),
                                });
                            }
                        }
                    }

                    let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
                    vars.sort_by_key(|v| v.id);
                    let scheme = Scheme::new(vars, declared_preds.clone(), expected.clone());
                    reject_ambiguous_scheme(&scheme)?;

                    Ok(FnInfo {
                        decl: decl.clone(),
                        expected,
                        declared_preds,
                        scheme,
                        ann_vars,
                    })
                })();

                infos.push(info.map_err(|err| with_span(&span, err))?);
            }

            // Pass 2: bind every declared scheme before checking any body.
            for info in &infos {
                self.env
                    .extend(info.decl.name.name.clone(), info.scheme.clone());
            }

            // Pass 3: check each body against its declared type.
            for info in infos {
                let span = info.decl.span;
                // Wrap the body in one annotated lambda per parameter,
                // innermost (last) parameter first.
                let mut lam_body = info.decl.body.clone();
                let mut lam_end = lam_body.span().end;
                for (param, ann) in info.decl.params.iter().rev() {
                    let lam_constraints = Vec::new();
                    let span = Span::from_begin_end(param.span.begin, lam_end);
                    lam_body = Arc::new(Expr::Lam(
                        span,
                        Scope::new_sync(),
                        param.clone(),
                        Some(ann.clone()),
                        lam_constraints,
                        lam_body,
                    ));
                    lam_end = lam_body.span().end;
                }

                // NOTE(review): errors from `infer_typed`, `unify` and
                // `entails` below are returned without `with_span`, unlike
                // the kind-mismatch and predicate errors — confirm that is
                // intended.
                let (typed, preds, inferred) = infer_typed(self, lam_body.as_ref())?;
                let s = unify(&inferred, &info.expected)?;
                let preds = preds.apply(&s);
                let inferred = inferred.apply(&s);
                let declared_preds = info.declared_preds.apply(&s);
                let expected = info.expected.apply(&s);

                // Same kind check as in pass 1, repeated on the predicates
                // after the unifier has been applied.
                let var_arities: HashMap<TypeVarId, usize> = info
                    .ann_vars
                    .values()
                    .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
                    .collect();
                for pred in &declared_preds {
                    let _ = entails(&self.classes, &[], pred)?;
                    let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
                    else {
                        continue;
                    };
                    let args: Vec<Type> = if expected_arities.len() == 1 {
                        vec![pred.typ.clone()]
                    } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
                        if parts.len() != expected_arities.len() {
                            continue;
                        }
                        parts.clone()
                    } else {
                        continue;
                    };

                    for (arg, expected_arity) in args.iter().zip(expected_arities.iter().copied()) {
                        let got = type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
                            TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
                            _ => None,
                        });
                        let Some(got) = got else {
                            continue;
                        };
                        if got != expected_arity {
                            return Err(with_span(
                                &span,
                                TypeError::KindMismatch {
                                    class: pred.class.clone(),
                                    expected: expected_arity,
                                    got,
                                    typ: arg.to_string(),
                                },
                            ));
                        }
                    }
                }

                // Validate the inferred predicates against the declared ones.
                check_non_ground_predicates_declared(&self.classes, &declared_preds, &preds)
                    .map_err(|err| with_span(&span, err))?;

                // The inferred type and typed tree are not needed beyond the
                // checks above.
                let _ = inferred;
                let _ = typed;
            }

            Ok(())
        })();

        // Roll back environment changes made by a failed registration.
        if result.is_err() {
            self.env = saved_env;
            self.declared_values = saved_declared;
        }
        result
    }
2137
2138 pub fn inject_declare_fn_decl(&mut self, decl: &DeclareFnDecl) -> Result<(), TypeError> {
2139 let span = decl.span;
2140 (|| {
2141 let mut sig = decl.ret.clone();
2143 for (_, ann) in decl.params.iter().rev() {
2144 let span = Span::from_begin_end(ann.span().begin, sig.span().end);
2145 sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
2146 }
2147
2148 let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
2149 let expected =
2150 type_from_annotation_expr_vars(&self.adts, &sig, &mut ann_vars, &mut self.supply)?;
2151 let declared_preds = predicates_from_constraints(
2152 &self.adts,
2153 &decl.constraints,
2154 &mut ann_vars,
2155 &mut self.supply,
2156 )?;
2157
2158 let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
2159 vars.sort_by_key(|v| v.id);
2160 let scheme = Scheme::new(vars, declared_preds, expected);
2161 reject_ambiguous_scheme(&scheme)?;
2162
2163 for pred in &scheme.preds {
2165 let _ = entails(&self.classes, &[], pred)?;
2166 }
2167
2168 let name = &decl.name.name;
2169
2170 if self.env.lookup(name).is_some() && !self.declared_values.contains(name) {
2173 return Ok(());
2174 }
2175
2176 if let Some(existing) = self.env.lookup(name) {
2177 if existing.iter().any(|s| scheme_compatible(s, &scheme)) {
2178 return Ok(());
2179 }
2180 return Err(TypeError::DuplicateValue(decl.name.name.clone()));
2181 }
2182
2183 self.env.extend(decl.name.name.clone(), scheme);
2184 self.declared_values.insert(decl.name.name.clone());
2185 Ok(())
2186 })()
2187 .map_err(|err| with_span(&span, err))
2188 }
2189
2190 pub fn instantiate_class_method_for_head(
2191 &mut self,
2192 class: &Symbol,
2193 method: &Symbol,
2194 head: &Type,
2195 ) -> Result<Type, TypeError> {
2196 let info = self
2197 .class_info
2198 .get(class)
2199 .ok_or_else(|| TypeError::UnknownClass(class.clone()))?;
2200 let scheme = info
2201 .methods
2202 .get(method)
2203 .ok_or_else(|| TypeError::UnknownInstanceMethod {
2204 class: class.clone(),
2205 method: method.clone(),
2206 })?;
2207
2208 let (preds, typ) = instantiate(scheme, &mut self.supply);
2209 let class_pred =
2210 preds
2211 .iter()
2212 .find(|p| &p.class == class)
2213 .ok_or(TypeError::UnsupportedExpr(
2214 "class method scheme missing class predicate",
2215 ))?;
2216 let s = unify(&class_pred.typ, head)?;
2217 Ok(typ.apply(&s))
2218 }
2219
2220 pub fn typecheck_instance_method(
2221 &mut self,
2222 prepared: &PreparedInstanceDecl,
2223 method: &InstanceMethodImpl,
2224 ) -> Result<TypedExpr, TypeError> {
2225 let expected =
2226 self.instantiate_class_method_for_head(&prepared.class, &method.name, &prepared.head)?;
2227 let (typed, preds, actual) = infer_typed(self, method.body.as_ref())?;
2228 let s = unify(&actual, &expected)?;
2229 let typed = typed.apply(&s);
2230 let preds = preds.apply(&s);
2231
2232 let mut given = prepared.context.clone();
2238
2239 given.push(Predicate::new(
2242 prepared.class.clone(),
2243 prepared.head.clone(),
2244 ));
2245 let mut i = 0;
2246 while i < given.len() {
2247 let p = given[i].clone();
2248 for sup in self.classes.supers_of(&p.class) {
2249 given.push(Predicate::new(sup, p.typ.clone()));
2250 }
2251 i += 1;
2252 }
2253
2254 for pred in &preds {
2255 if pred.typ.ftv().is_empty() {
2256 if !entails(&self.classes, &given, pred)? {
2257 return Err(TypeError::NoInstance(
2258 pred.class.clone(),
2259 pred.typ.to_string(),
2260 ));
2261 }
2262 } else if !given
2263 .iter()
2264 .any(|p| p.class == pred.class && p.typ == pred.typ)
2265 {
2266 return Err(TypeError::MissingInstanceConstraint {
2267 method: method.name.clone(),
2268 class: pred.class.clone(),
2269 typ: pred.typ.to_string(),
2270 });
2271 }
2272 }
2273
2274 Ok(typed)
2275 }
2276
2277 pub fn register_adt(&mut self, adt: &AdtDecl) {
2281 self.adts.insert(adt.name.clone(), adt.clone());
2282 for (name, scheme) in adt.constructor_schemes() {
2283 self.register_value_scheme(&name, scheme);
2284 }
2285 }
2286
2287 pub fn adt_from_decl(&mut self, decl: &TypeDecl) -> Result<AdtDecl, TypeError> {
2288 let mut adt = AdtDecl::new(&decl.name, &decl.params, &mut self.supply);
2289 let mut param_map: HashMap<Symbol, TypeVar> = HashMap::new();
2290 for param in &adt.params {
2291 param_map.insert(param.name.clone(), param.var.clone());
2292 }
2293
2294 for variant in &decl.variants {
2295 let mut args = Vec::new();
2296 for arg in &variant.args {
2297 let ty = self.type_from_expr(decl, ¶m_map, arg)?;
2298 args.push(ty);
2299 }
2300 adt.add_variant(variant.name.clone(), args);
2301 }
2302 Ok(adt)
2303 }
2304
2305 pub fn register_type_decl(&mut self, decl: &TypeDecl) -> Result<(), TypeError> {
2306 if BuiltinTypeId::from_symbol(&decl.name).is_some() {
2307 return Err(TypeError::ReservedTypeName(decl.name.clone()));
2308 }
2309 let adt = self.adt_from_decl(decl)?;
2310 self.register_adt(&adt);
2311 Ok(())
2312 }
2313
2314 fn type_from_expr(
2315 &mut self,
2316 decl: &TypeDecl,
2317 params: &HashMap<Symbol, TypeVar>,
2318 expr: &TypeExpr,
2319 ) -> Result<Type, TypeError> {
2320 let span = *expr.span();
2321 let res = (|| match expr {
2322 TypeExpr::Name(_, name) => {
2323 let name_sym = name.to_dotted_symbol();
2324 if let Some(tv) = params.get(&name_sym) {
2325 Ok(Type::var(tv.clone()))
2326 } else {
2327 let name = normalize_type_name(&name_sym);
2328 if let Some(arity) = self.type_arity(decl, &name) {
2329 Ok(Type::con(name, arity))
2330 } else {
2331 Err(TypeError::UnknownTypeName(name))
2332 }
2333 }
2334 }
2335 TypeExpr::App(_, fun, arg) => {
2336 let fty = self.type_from_expr(decl, params, fun)?;
2337 let aty = self.type_from_expr(decl, params, arg)?;
2338 Ok(type_app_with_result_syntax(fty, aty))
2339 }
2340 TypeExpr::Fun(_, arg, ret) => {
2341 let arg_ty = self.type_from_expr(decl, params, arg)?;
2342 let ret_ty = self.type_from_expr(decl, params, ret)?;
2343 Ok(Type::fun(arg_ty, ret_ty))
2344 }
2345 TypeExpr::Tuple(_, elems) => {
2346 let mut out = Vec::new();
2347 for elem in elems {
2348 out.push(self.type_from_expr(decl, params, elem)?);
2349 }
2350 Ok(Type::tuple(out))
2351 }
2352 TypeExpr::Record(_, fields) => {
2353 let mut out = Vec::new();
2354 for (name, ty) in fields {
2355 out.push((name.clone(), self.type_from_expr(decl, params, ty)?));
2356 }
2357 Ok(Type::record(out))
2358 }
2359 })();
2360 res.map_err(|err| with_span(&span, err))
2361 }
2362
2363 fn type_arity(&self, decl: &TypeDecl, name: &Symbol) -> Option<usize> {
2364 if &decl.name == name {
2365 return Some(decl.params.len());
2366 }
2367 if let Some(adt) = self.adts.get(name) {
2368 return Some(adt.params.len());
2369 }
2370 BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2371 }
2372
2373 fn register_value_scheme(&mut self, name: &Symbol, scheme: Scheme) {
2374 match self.env.lookup(name) {
2375 None => self.env.extend(name.clone(), scheme),
2376 Some(existing) => {
2377 if existing.iter().any(|s| unify(&s.typ, &scheme.typ).is_ok()) {
2378 return;
2379 }
2380 self.env.extend_overload(name.clone(), scheme);
2381 }
2382 }
2383 }
2384
2385 fn expected_class_param_arities(&self, class: &Symbol) -> Option<Vec<usize>> {
2386 let info = self.class_info.get(class)?;
2387 let mut out = vec![0usize; info.params.len()];
2388 for scheme in info.methods.values() {
2389 for (idx, param) in info.params.iter().enumerate() {
2390 let Some(tv) = scheme.vars.iter().find(|v| v.name.as_ref() == Some(param)) else {
2391 continue;
2392 };
2393 out[idx] = out[idx].max(max_head_app_arity_for_var(&scheme.typ, tv.id));
2394 }
2395 }
2396 Some(out)
2397 }
2398
2399 fn check_predicate_kind(&self, pred: &Predicate) -> Result<(), TypeError> {
2400 let Some(expected) = self.expected_class_param_arities(&pred.class) else {
2401 return Ok(());
2403 };
2404
2405 let args: Vec<Type> = if expected.len() == 1 {
2406 vec![pred.typ.clone()]
2407 } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2408 if parts.len() != expected.len() {
2409 return Ok(());
2410 }
2411 parts.clone()
2412 } else {
2413 return Ok(());
2414 };
2415
2416 for (arg, expected_arity) in args.iter().zip(expected.iter().copied()) {
2417 let Some(got) = type_term_remaining_arity(arg) else {
2418 continue;
2422 };
2423 if got != expected_arity {
2424 return Err(TypeError::KindMismatch {
2425 class: pred.class.clone(),
2426 expected: expected_arity,
2427 got,
2428 typ: arg.to_string(),
2429 });
2430 }
2431 }
2432 Ok(())
2433 }
2434
2435 fn check_predicate_kinds(&self, preds: &[Predicate]) -> Result<(), TypeError> {
2436 for pred in preds {
2437 self.check_predicate_kind(pred)?;
2438 }
2439 Ok(())
2440 }
2441}
2442
2443fn type_from_annotation_expr(
2444 adts: &HashMap<Symbol, AdtDecl>,
2445 expr: &TypeExpr,
2446) -> Result<Type, TypeError> {
2447 let span = *expr.span();
2448 let res = (|| match expr {
2449 TypeExpr::Name(_, name) => {
2450 let name = normalize_type_name(&name.to_dotted_symbol());
2451 match annotation_type_arity(adts, &name) {
2452 Some(arity) => Ok(Type::con(name, arity)),
2453 None => Err(TypeError::UnknownTypeName(name)),
2454 }
2455 }
2456 TypeExpr::App(_, fun, arg) => {
2457 let fty = type_from_annotation_expr(adts, fun)?;
2458 let aty = type_from_annotation_expr(adts, arg)?;
2459 Ok(type_app_with_result_syntax(fty, aty))
2460 }
2461 TypeExpr::Fun(_, arg, ret) => {
2462 let arg_ty = type_from_annotation_expr(adts, arg)?;
2463 let ret_ty = type_from_annotation_expr(adts, ret)?;
2464 Ok(Type::fun(arg_ty, ret_ty))
2465 }
2466 TypeExpr::Tuple(_, elems) => {
2467 let mut out = Vec::new();
2468 for elem in elems {
2469 out.push(type_from_annotation_expr(adts, elem)?);
2470 }
2471 Ok(Type::tuple(out))
2472 }
2473 TypeExpr::Record(_, fields) => {
2474 let mut out = Vec::new();
2475 for (name, ty) in fields {
2476 out.push((name.clone(), type_from_annotation_expr(adts, ty)?));
2477 }
2478 Ok(Type::record(out))
2479 }
2480 })();
2481 res.map_err(|err| with_span(&span, err))
2482}
2483
2484fn type_from_annotation_expr_vars(
2485 adts: &HashMap<Symbol, AdtDecl>,
2486 expr: &TypeExpr,
2487 vars: &mut HashMap<Symbol, TypeVar>,
2488 supply: &mut TypeVarSupply,
2489) -> Result<Type, TypeError> {
2490 let span = *expr.span();
2491 let res = (|| match expr {
2492 TypeExpr::Name(_, name) => {
2493 let name = normalize_type_name(&name.to_dotted_symbol());
2494 if let Some(arity) = annotation_type_arity(adts, &name) {
2495 Ok(Type::con(name, arity))
2496 } else if let Some(tv) = vars.get(&name) {
2497 Ok(Type::var(tv.clone()))
2498 } else {
2499 let is_upper = name
2500 .chars()
2501 .next()
2502 .map(|c| c.is_uppercase())
2503 .unwrap_or(false);
2504 if is_upper {
2505 return Err(TypeError::UnknownTypeName(name));
2506 }
2507 let tv = supply.fresh(Some(name.clone()));
2508 vars.insert(name.clone(), tv.clone());
2509 Ok(Type::var(tv))
2510 }
2511 }
2512 TypeExpr::App(_, fun, arg) => {
2513 let fty = type_from_annotation_expr_vars(adts, fun, vars, supply)?;
2514 let aty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2515 Ok(type_app_with_result_syntax(fty, aty))
2516 }
2517 TypeExpr::Fun(_, arg, ret) => {
2518 let arg_ty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2519 let ret_ty = type_from_annotation_expr_vars(adts, ret, vars, supply)?;
2520 Ok(Type::fun(arg_ty, ret_ty))
2521 }
2522 TypeExpr::Tuple(_, elems) => {
2523 let mut out = Vec::new();
2524 for elem in elems {
2525 out.push(type_from_annotation_expr_vars(adts, elem, vars, supply)?);
2526 }
2527 Ok(Type::tuple(out))
2528 }
2529 TypeExpr::Record(_, fields) => {
2530 let mut out = Vec::new();
2531 for (name, ty) in fields {
2532 out.push((
2533 name.clone(),
2534 type_from_annotation_expr_vars(adts, ty, vars, supply)?,
2535 ));
2536 }
2537 Ok(Type::record(out))
2538 }
2539 })();
2540 res.map_err(|err| with_span(&span, err))
2541}
2542
2543fn annotation_type_arity(adts: &HashMap<Symbol, AdtDecl>, name: &Symbol) -> Option<usize> {
2544 if let Some(adt) = adts.get(name) {
2545 return Some(adt.params.len());
2546 }
2547 BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2548}
2549
2550fn normalize_type_name(name: &Symbol) -> Symbol {
2551 if name.as_ref() == "str" {
2552 BuiltinTypeId::String.as_symbol()
2553 } else {
2554 name.clone()
2555 }
2556}
2557
2558fn type_app_with_result_syntax(fun: Type, arg: Type) -> Type {
2559 if let TypeKind::App(head, ok) = fun.as_ref()
2560 && matches!(
2561 head.as_ref(),
2562 TypeKind::Con(c)
2563 if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
2564 )
2565 {
2566 return Type::app(Type::app(head.clone(), arg), ok.clone());
2567 }
2568 Type::app(fun, arg)
2569}
2570
2571fn predicates_from_constraints(
2572 adts: &HashMap<Symbol, AdtDecl>,
2573 constraints: &[TypeConstraint],
2574 vars: &mut HashMap<Symbol, TypeVar>,
2575 supply: &mut TypeVarSupply,
2576) -> Result<Vec<Predicate>, TypeError> {
2577 let mut out = Vec::with_capacity(constraints.len());
2578 for constraint in constraints {
2579 let ty = type_from_annotation_expr_vars(adts, &constraint.typ, vars, supply)?;
2580 out.push(Predicate::new(constraint.class.as_ref(), ty));
2581 }
2582 Ok(out)
2583}
2584
/// One name that refers to more than one distinct type constructor.
///
/// Produced by `collect_adts_in_types` when the same constructor name shows up
/// with differing definitions.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AdtConflict {
    /// The constructor name shared by the conflicting definitions.
    pub name: Symbol,
    /// The distinct constructor types found under `name`, in the order they
    /// were first encountered.
    pub definitions: Vec<Type>,
}
2590
/// Error returned by `collect_adts_in_types` when at least one constructor
/// name has conflicting definitions.
///
/// The `Display` output embeds the `Debug` rendering of `conflicts`.
#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)]
#[error("conflicting ADT definitions: {conflicts:?}")]
pub struct CollectAdtsError {
    /// Every conflicting name together with its competing definitions.
    pub conflicts: Vec<AdtConflict>,
}
2596
2597pub fn collect_adts_in_types(types: Vec<Type>) -> Result<Vec<Type>, CollectAdtsError> {
2634 fn visit(
2635 typ: &Type,
2636 out: &mut Vec<Type>,
2637 seen: &mut HashSet<Type>,
2638 defs_by_name: &mut BTreeMap<Symbol, Vec<Type>>,
2639 ) {
2640 match typ.as_ref() {
2641 TypeKind::Var(_) => {}
2642 TypeKind::Con(tc) => {
2643 if tc.builtin_id.is_none() {
2645 let adt = Type::new(TypeKind::Con(tc.clone()));
2646 if seen.insert(adt.clone()) {
2647 out.push(adt.clone());
2648 }
2649 let defs = defs_by_name.entry(tc.name.clone()).or_default();
2650 if !defs.contains(&adt) {
2651 defs.push(adt);
2652 }
2653 }
2654 }
2655 TypeKind::App(fun, arg) => {
2656 visit(fun, out, seen, defs_by_name);
2657 visit(arg, out, seen, defs_by_name);
2658 }
2659 TypeKind::Fun(arg, ret) => {
2660 visit(arg, out, seen, defs_by_name);
2661 visit(ret, out, seen, defs_by_name);
2662 }
2663 TypeKind::Tuple(elems) => {
2664 for elem in elems {
2665 visit(elem, out, seen, defs_by_name);
2666 }
2667 }
2668 TypeKind::Record(fields) => {
2669 for (_name, field_ty) in fields {
2670 visit(field_ty, out, seen, defs_by_name);
2671 }
2672 }
2673 }
2674 }
2675
2676 let mut out = Vec::new();
2677 let mut seen = HashSet::new();
2678 let mut defs_by_name: BTreeMap<Symbol, Vec<Type>> = BTreeMap::new();
2679 for typ in &types {
2680 visit(typ, &mut out, &mut seen, &mut defs_by_name);
2681 }
2682
2683 let conflicts: Vec<AdtConflict> = defs_by_name
2684 .into_iter()
2685 .filter_map(|(name, definitions)| {
2686 (definitions.len() > 1).then_some(AdtConflict { name, definitions })
2687 })
2688 .collect();
2689 if !conflicts.is_empty() {
2690 return Err(CollectAdtsError { conflicts });
2691 }
2692
2693 Ok(out)
2694}