1use std::collections::{BTreeMap, HashMap, HashSet};
4use std::fmt::{self, Display, Formatter};
5use std::sync::Arc;
6
7use chrono::{DateTime, Utc};
8use rexlang_ast::expr::{
9 ClassDecl, ClassMethodSig, Decl, DeclareFnDecl, Expr, FnDecl, InstanceDecl, InstanceMethodImpl,
10 Pattern, Scope, Symbol, TypeConstraint, TypeDecl, TypeExpr, intern, sym,
11};
12use rexlang_lexer::span::Span;
13use rexlang_util::{GasMeter, OutOfGas};
14use rpds::HashTrieMapSync;
15use uuid::Uuid;
16
17use crate::prelude;
18
/// Numeric identity of a type variable.
///
/// Equality of type variables during unification is decided by this id alone, and the
/// unifier uses it directly as an index into its binding table.
pub type TypeVarId = usize;
20
/// The closed set of type constructors the type system knows natively.
///
/// Scalars (`U8` … `DateTime`) take no type arguments; `List`, `Array`, `Dict` and
/// `Option` take one and `Result` takes two (see `BuiltinTypeId::arity`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BuiltinTypeId {
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F32,
    F64,
    Bool,
    String,
    Uuid,
    DateTime,
    List,
    Array,
    Dict,
    Option,
    Result,
}
43
44impl BuiltinTypeId {
45 pub fn as_symbol(self) -> Symbol {
46 sym(self.as_str())
47 }
48
49 pub fn as_str(self) -> &'static str {
50 match self {
51 Self::U8 => "u8",
52 Self::U16 => "u16",
53 Self::U32 => "u32",
54 Self::U64 => "u64",
55 Self::I8 => "i8",
56 Self::I16 => "i16",
57 Self::I32 => "i32",
58 Self::I64 => "i64",
59 Self::F32 => "f32",
60 Self::F64 => "f64",
61 Self::Bool => "bool",
62 Self::String => "string",
63 Self::Uuid => "uuid",
64 Self::DateTime => "datetime",
65 Self::List => "List",
66 Self::Array => "Array",
67 Self::Dict => "Dict",
68 Self::Option => "Option",
69 Self::Result => "Result",
70 }
71 }
72
73 pub fn arity(self) -> usize {
74 match self {
75 Self::List | Self::Array | Self::Dict | Self::Option => 1,
76 Self::Result => 2,
77 _ => 0,
78 }
79 }
80
81 pub fn from_symbol(name: &Symbol) -> Option<Self> {
82 Self::from_name(name.as_ref())
83 }
84
85 pub fn from_name(name: &str) -> Option<Self> {
86 match name {
87 "u8" => Some(Self::U8),
88 "u16" => Some(Self::U16),
89 "u32" => Some(Self::U32),
90 "u64" => Some(Self::U64),
91 "i8" => Some(Self::I8),
92 "i16" => Some(Self::I16),
93 "i32" => Some(Self::I32),
94 "i64" => Some(Self::I64),
95 "f32" => Some(Self::F32),
96 "f64" => Some(Self::F64),
97 "bool" => Some(Self::Bool),
98 "string" => Some(Self::String),
99 "uuid" => Some(Self::Uuid),
100 "datetime" => Some(Self::DateTime),
101 "List" => Some(Self::List),
102 "Array" => Some(Self::Array),
103 "Dict" => Some(Self::Dict),
104 "Option" => Some(Self::Option),
105 "Result" => Some(Self::Result),
106 _ => None,
107 }
108 }
109}
110
111#[derive(Clone, Debug, Eq, Hash, PartialEq)]
112pub struct TypeVar {
113 pub id: TypeVarId,
114 pub name: Option<Symbol>,
115}
116
117impl TypeVar {
118 pub fn new(id: TypeVarId, name: impl Into<Option<Symbol>>) -> Self {
119 Self {
120 id,
121 name: name.into(),
122 }
123 }
124}
125
/// A named type constructor.
///
/// Two constructors are equal only if name, arity *and* `builtin_id` all match, so a
/// user constructor that happens to share a builtin's name is a distinct type.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TypeConst {
    pub name: Symbol,
    /// Number of type arguments this constructor takes.
    pub arity: usize,
    /// `Some` when this constructor is one of the natively known builtins.
    pub builtin_id: Option<BuiltinTypeId>,
}
132
/// A cheaply clonable, shared type term (an `Arc` around a [`TypeKind`]).
///
/// Equality and hashing are structural (they compare the pointed-to `TypeKind`).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Type(Arc<TypeKind>);

/// The shape of a type term.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TypeKind {
    /// A type variable.
    Var(TypeVar),
    /// A type constructor (possibly not yet applied to its arguments).
    Con(TypeConst),
    /// Type application `f arg`; multi-argument constructors are curried.
    App(Type, Type),
    /// Function type `arg -> result`.
    Fun(Type, Type),
    /// Tuple type.
    Tuple(Vec<Type>),
    /// Record type; `Type::record` sorts the fields by name when constructing one.
    Record(Vec<(Symbol, Type)>),
}
149
150impl Type {
151 pub fn new(kind: TypeKind) -> Self {
152 Type(Arc::new(kind))
153 }
154
155 pub fn con(name: impl AsRef<str>, arity: usize) -> Self {
156 if let Some(id) = BuiltinTypeId::from_name(name.as_ref())
157 && id.arity() == arity
158 {
159 return Self::builtin(id);
160 }
161 Self::user_con(name, arity)
162 }
163
164 pub fn user_con(name: impl AsRef<str>, arity: usize) -> Self {
165 Type::new(TypeKind::Con(TypeConst {
166 name: intern(name.as_ref()),
167 arity,
168 builtin_id: None,
169 }))
170 }
171
172 pub fn builtin(id: BuiltinTypeId) -> Self {
173 Type::new(TypeKind::Con(TypeConst {
174 name: id.as_symbol(),
175 arity: id.arity(),
176 builtin_id: Some(id),
177 }))
178 }
179
180 pub fn var(tv: TypeVar) -> Self {
181 Type::new(TypeKind::Var(tv))
182 }
183
184 pub fn fun(a: Type, b: Type) -> Self {
185 Type::new(TypeKind::Fun(a, b))
186 }
187
188 pub fn app(f: Type, arg: Type) -> Self {
189 Type::new(TypeKind::App(f, arg))
190 }
191
192 pub fn tuple(elems: Vec<Type>) -> Self {
193 Type::new(TypeKind::Tuple(elems))
194 }
195
196 pub fn record(mut fields: Vec<(Symbol, Type)>) -> Self {
197 fields.sort_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
200 Type::new(TypeKind::Record(fields))
201 }
202
203 pub fn list(elem: Type) -> Type {
204 Type::app(Type::builtin(BuiltinTypeId::List), elem)
205 }
206
207 pub fn array(elem: Type) -> Type {
208 Type::app(Type::builtin(BuiltinTypeId::Array), elem)
209 }
210
211 pub fn dict(elem: Type) -> Type {
212 Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
213 }
214
215 pub fn option(elem: Type) -> Type {
216 Type::app(Type::builtin(BuiltinTypeId::Option), elem)
217 }
218
219 pub fn result(ok: Type, err: Type) -> Type {
220 Type::app(Type::app(Type::builtin(BuiltinTypeId::Result), err), ok)
221 }
222
223 fn apply_with_change(&self, s: &Subst) -> (Type, bool) {
224 match self.as_ref() {
225 TypeKind::Var(tv) => match s.get(&tv.id) {
226 Some(ty) => (ty.clone(), true),
227 None => (self.clone(), false),
228 },
229 TypeKind::Con(_) => (self.clone(), false),
230 TypeKind::App(l, r) => {
231 let (l_new, l_changed) = l.apply_with_change(s);
232 let (r_new, r_changed) = r.apply_with_change(s);
233 if l_changed || r_changed {
234 (Type::app(l_new, r_new), true)
235 } else {
236 (self.clone(), false)
237 }
238 }
239 TypeKind::Fun(_, _) => {
240 let mut args = Vec::new();
243 let mut changed = false;
244 let mut cur: &Type = self;
245 while let TypeKind::Fun(a, b) = cur.as_ref() {
246 let (a_new, a_changed) = a.apply_with_change(s);
247 changed |= a_changed;
248 args.push(a_new);
249 cur = b;
250 }
251 let (ret_new, ret_changed) = cur.apply_with_change(s);
252 changed |= ret_changed;
253 if !changed {
254 return (self.clone(), false);
255 }
256 let mut out = ret_new;
257 for a_new in args.into_iter().rev() {
258 out = Type::fun(a_new, out);
259 }
260 (out, true)
261 }
262 TypeKind::Tuple(ts) => {
263 let mut changed = false;
264 let mut out = Vec::with_capacity(ts.len());
265 for t in ts {
266 let (t_new, t_changed) = t.apply_with_change(s);
267 changed |= t_changed;
268 out.push(t_new);
269 }
270 if changed {
271 (Type::new(TypeKind::Tuple(out)), true)
272 } else {
273 (self.clone(), false)
274 }
275 }
276 TypeKind::Record(fields) => {
277 let mut changed = false;
278 let mut out = Vec::with_capacity(fields.len());
279 for (k, v) in fields {
280 let (v_new, v_changed) = v.apply_with_change(s);
281 changed |= v_changed;
282 out.push((k.clone(), v_new));
283 }
284 if changed {
285 (Type::new(TypeKind::Record(out)), true)
286 } else {
287 (self.clone(), false)
288 }
289 }
290 }
291 }
292}
293
294impl AsRef<TypeKind> for Type {
295 fn as_ref(&self) -> &TypeKind {
296 self.0.as_ref()
297 }
298}
299
300impl std::ops::Deref for Type {
301 type Target = TypeKind;
302
303 fn deref(&self) -> &Self::Target {
304 &self.0
305 }
306}
307
308impl Display for Type {
309 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
310 match self.as_ref() {
311 TypeKind::Var(tv) => match &tv.name {
312 Some(name) => write!(f, "'{}", name),
313 None => write!(f, "t{}", tv.id),
314 },
315 TypeKind::Con(c) => write!(f, "{}", c.name),
316 TypeKind::App(l, r) => {
317 if let TypeKind::App(head, err) = l.as_ref()
323 && matches!(
324 head.as_ref(),
325 TypeKind::Con(c)
326 if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
327 )
328 {
329 return write!(f, "(Result {} {})", r, err);
330 }
331 write!(f, "({} {})", l, r)
332 }
333 TypeKind::Fun(a, b) => write!(f, "({} -> {})", a, b),
334 TypeKind::Tuple(elems) => {
335 write!(f, "(")?;
336 for (i, t) in elems.iter().enumerate() {
337 write!(f, "{}", t)?;
338 if i + 1 < elems.len() {
339 write!(f, ", ")?;
340 }
341 }
342 write!(f, ")")
343 }
344 TypeKind::Record(fields) => {
345 write!(f, "{{")?;
346 for (i, (name, ty)) in fields.iter().enumerate() {
347 write!(f, "{}: {}", name, ty)?;
348 if i + 1 < fields.len() {
349 write!(f, ", ")?;
350 }
351 }
352 write!(f, "}}")
353 }
354 }
355 }
356}
357
358#[derive(Clone, Debug, PartialEq, Eq, Hash)]
359pub struct Predicate {
360 pub class: Symbol,
361 pub typ: Type,
362}
363
364impl Predicate {
365 pub fn new(class: impl AsRef<str>, typ: Type) -> Self {
366 Self {
367 class: intern(class.as_ref()),
368 typ,
369 }
370 }
371}
372
373#[derive(Clone, Debug, PartialEq)]
374pub struct Scheme {
375 pub vars: Vec<TypeVar>,
376 pub preds: Vec<Predicate>,
377 pub typ: Type,
378}
379
380impl Scheme {
381 pub fn new(vars: Vec<TypeVar>, preds: Vec<Predicate>, typ: Type) -> Self {
382 Self { vars, preds, typ }
383 }
384}
385
/// A substitution from type-variable ids to types.
///
/// This is a persistent (immutable, structurally shared) map: `insert` returns a new
/// map rather than mutating in place.
pub type Subst = HashTrieMapSync<TypeVarId, Type>;

/// Things that contain types and can be substituted into / scanned for variables.
pub trait Types: Sized {
    /// Returns a copy of `self` with `s` applied to every contained type.
    fn apply(&self, s: &Subst) -> Self;
    /// Ids of the type variables occurring free in `self`.
    fn ftv(&self) -> HashSet<TypeVarId>;
}
392
393impl Types for Type {
394 fn apply(&self, s: &Subst) -> Self {
395 self.apply_with_change(s).0
396 }
397
398 fn ftv(&self) -> HashSet<TypeVarId> {
399 let mut out = HashSet::new();
400 let mut stack: Vec<&Type> = vec![self];
401 while let Some(t) = stack.pop() {
402 match t.as_ref() {
403 TypeKind::Var(tv) => {
404 out.insert(tv.id);
405 }
406 TypeKind::Con(_) => {}
407 TypeKind::App(l, r) => {
408 stack.push(l);
409 stack.push(r);
410 }
411 TypeKind::Fun(a, b) => {
412 stack.push(a);
413 stack.push(b);
414 }
415 TypeKind::Tuple(ts) => {
416 for t in ts {
417 stack.push(t);
418 }
419 }
420 TypeKind::Record(fields) => {
421 for (_, ty) in fields {
422 stack.push(ty);
423 }
424 }
425 }
426 }
427 out
428 }
429}
430
431impl Types for Predicate {
432 fn apply(&self, s: &Subst) -> Self {
433 Predicate {
434 class: self.class.clone(),
435 typ: self.typ.apply(s),
436 }
437 }
438
439 fn ftv(&self) -> HashSet<TypeVarId> {
440 self.typ.ftv()
441 }
442}
443
444impl Types for Scheme {
445 fn apply(&self, s: &Subst) -> Self {
446 let mut s_pruned = Subst::new_sync();
447 for (k, v) in s.iter() {
448 if !self.vars.iter().any(|var| var.id == *k) {
449 s_pruned = s_pruned.insert(*k, v.clone());
450 }
451 }
452 Scheme::new(
453 self.vars.clone(),
454 self.preds.iter().map(|p| p.apply(&s_pruned)).collect(),
455 self.typ.apply(&s_pruned),
456 )
457 }
458
459 fn ftv(&self) -> HashSet<TypeVarId> {
460 let mut ftv = self.typ.ftv();
461 for p in &self.preds {
462 ftv.extend(p.ftv());
463 }
464 for v in &self.vars {
465 ftv.remove(&v.id);
466 }
467 ftv
468 }
469}
470
471impl<T: Types> Types for Vec<T> {
472 fn apply(&self, s: &Subst) -> Self {
473 self.iter().map(|t| t.apply(s)).collect()
474 }
475
476 fn ftv(&self) -> HashSet<TypeVarId> {
477 self.iter().flat_map(Types::ftv).collect()
478 }
479}
480
/// An expression node annotated with its inferred type.
#[derive(Clone, Debug, PartialEq)]
pub struct TypedExpr {
    /// Type of this node (for `Lam` nodes, the whole lambda's type).
    pub typ: Type,
    pub kind: TypedExprKind,
}
486
487impl TypedExpr {
488 pub fn new(typ: Type, kind: TypedExprKind) -> Self {
489 Self { typ, kind }
490 }
491
492 pub fn apply(&self, s: &Subst) -> Self {
493 match &self.kind {
494 TypedExprKind::Lam { .. } => {
495 let mut params: Vec<(Symbol, Type)> = Vec::new();
496 let mut cur = self;
497 while let TypedExprKind::Lam { param, body } = &cur.kind {
498 params.push((param.clone(), cur.typ.apply(s)));
499 cur = body.as_ref();
500 }
501 let mut out = cur.apply(s);
502 for (param, typ) in params.into_iter().rev() {
503 out = TypedExpr {
504 typ,
505 kind: TypedExprKind::Lam {
506 param,
507 body: Box::new(out),
508 },
509 };
510 }
511 return out;
512 }
513 TypedExprKind::App(..) => {
514 let mut apps: Vec<(Type, &TypedExpr)> = Vec::new();
515 let mut cur = self;
516 while let TypedExprKind::App(f, x) = &cur.kind {
517 apps.push((cur.typ.apply(s), x.as_ref()));
518 cur = f.as_ref();
519 }
520 let mut out = cur.apply(s);
521 for (typ, arg) in apps.into_iter().rev() {
522 out = TypedExpr {
523 typ,
524 kind: TypedExprKind::App(Box::new(out), Box::new(arg.apply(s))),
525 };
526 }
527 return out;
528 }
529 _ => {}
530 }
531
532 let typ = self.typ.apply(s);
533 let kind = match &self.kind {
534 TypedExprKind::Bool(v) => TypedExprKind::Bool(*v),
535 TypedExprKind::Uint(v) => TypedExprKind::Uint(*v),
536 TypedExprKind::Int(v) => TypedExprKind::Int(*v),
537 TypedExprKind::Float(v) => TypedExprKind::Float(*v),
538 TypedExprKind::String(v) => TypedExprKind::String(v.clone()),
539 TypedExprKind::Uuid(v) => TypedExprKind::Uuid(*v),
540 TypedExprKind::DateTime(v) => TypedExprKind::DateTime(*v),
541 TypedExprKind::Hole => TypedExprKind::Hole,
542 TypedExprKind::Tuple(elems) => {
543 TypedExprKind::Tuple(elems.iter().map(|e| e.apply(s)).collect())
544 }
545 TypedExprKind::List(elems) => {
546 TypedExprKind::List(elems.iter().map(|e| e.apply(s)).collect())
547 }
548 TypedExprKind::Dict(kvs) => {
549 let mut out = BTreeMap::new();
550 for (k, v) in kvs {
551 out.insert(k.clone(), v.apply(s));
552 }
553 TypedExprKind::Dict(out)
554 }
555 TypedExprKind::RecordUpdate { base, updates } => {
556 let mut out = BTreeMap::new();
557 for (k, v) in updates {
558 out.insert(k.clone(), v.apply(s));
559 }
560 TypedExprKind::RecordUpdate {
561 base: Box::new(base.apply(s)),
562 updates: out,
563 }
564 }
565 TypedExprKind::Var { name, overloads } => TypedExprKind::Var {
566 name: name.clone(),
567 overloads: overloads.iter().map(|t| t.apply(s)).collect(),
568 },
569 TypedExprKind::App(f, x) => {
570 TypedExprKind::App(Box::new(f.apply(s)), Box::new(x.apply(s)))
571 }
572 TypedExprKind::Project { expr, field } => TypedExprKind::Project {
573 expr: Box::new(expr.apply(s)),
574 field: field.clone(),
575 },
576 TypedExprKind::Lam { param, body } => TypedExprKind::Lam {
577 param: param.clone(),
578 body: Box::new(body.apply(s)),
579 },
580 TypedExprKind::Let { name, def, body } => TypedExprKind::Let {
581 name: name.clone(),
582 def: Box::new(def.apply(s)),
583 body: Box::new(body.apply(s)),
584 },
585 TypedExprKind::LetRec { bindings, body } => TypedExprKind::LetRec {
586 bindings: bindings
587 .iter()
588 .map(|(name, def)| (name.clone(), def.apply(s)))
589 .collect(),
590 body: Box::new(body.apply(s)),
591 },
592 TypedExprKind::Ite {
593 cond,
594 then_expr,
595 else_expr,
596 } => TypedExprKind::Ite {
597 cond: Box::new(cond.apply(s)),
598 then_expr: Box::new(then_expr.apply(s)),
599 else_expr: Box::new(else_expr.apply(s)),
600 },
601 TypedExprKind::Match { scrutinee, arms } => TypedExprKind::Match {
602 scrutinee: Box::new(scrutinee.apply(s)),
603 arms: arms.iter().map(|(p, e)| (p.clone(), e.apply(s))).collect(),
604 },
605 };
606 TypedExpr { typ, kind }
607 }
608}
609
/// The node variants of a [`TypedExpr`].
#[derive(Clone, Debug, PartialEq)]
pub enum TypedExprKind {
    Bool(bool),
    Uint(u64),
    Int(i64),
    Float(f64),
    String(String),
    Uuid(Uuid),
    DateTime(DateTime<Utc>),
    /// A hole placeholder (NOTE(review): exact surface syntax not visible here — confirm).
    Hole,
    Tuple(Vec<TypedExpr>),
    List(Vec<TypedExpr>),
    /// Keyed entries, ordered by key via `BTreeMap`.
    Dict(BTreeMap<Symbol, TypedExpr>),
    /// `base` with the listed fields replaced by new values.
    RecordUpdate {
        base: Box<TypedExpr>,
        updates: BTreeMap<Symbol, TypedExpr>,
    },
    /// A variable reference; `overloads` holds candidate types (substituted along
    /// with the rest of the tree by `TypedExpr::apply`).
    Var {
        name: Symbol,
        overloads: Vec<Type>,
    },
    /// Single-argument application `func arg`; multi-argument calls are nested.
    App(Box<TypedExpr>, Box<TypedExpr>),
    /// Field access `expr.field`.
    Project {
        expr: Box<TypedExpr>,
        field: Symbol,
    },
    /// Single-parameter lambda; curried lambdas are nested.
    Lam {
        param: Symbol,
        body: Box<TypedExpr>,
    },
    Let {
        name: Symbol,
        def: Box<TypedExpr>,
        body: Box<TypedExpr>,
    },
    /// Group of (presumably mutually recursive) bindings scoped over `body`.
    LetRec {
        bindings: Vec<(Symbol, TypedExpr)>,
        body: Box<TypedExpr>,
    },
    /// `if cond then then_expr else else_expr`.
    Ite {
        cond: Box<TypedExpr>,
        then_expr: Box<TypedExpr>,
        else_expr: Box<TypedExpr>,
    },
    /// Pattern match; arms are tried in the listed order (assumed — confirm in evaluator).
    Match {
        scrutinee: Box<TypedExpr>,
        arms: Vec<(Pattern, TypedExpr)>,
    },
}
659
660pub fn compose_subst(a: Subst, b: Subst) -> Subst {
665 if subst_is_empty(&a) {
666 return b;
667 }
668 if subst_is_empty(&b) {
669 return a;
670 }
671 let mut res = Subst::new_sync();
672 for (k, v) in b.iter() {
673 res = res.insert(*k, v.apply(&a));
674 }
675 for (k, v) in a.iter() {
676 res = res.insert(*k, v.clone());
677 }
678 res
679}
680
681fn subst_is_empty(s: &Subst) -> bool {
682 s.iter().next().is_none()
683}
684
685fn dedup_preds(preds: Vec<Predicate>) -> Vec<Predicate> {
686 let mut seen = HashSet::new();
687 let mut out = Vec::with_capacity(preds.len());
688 for pred in preds {
689 if seen.insert(pred.clone()) {
690 out.push(pred);
691 }
692 }
693 out
694}
695
696fn is_integral_primitive(typ: &Type) -> bool {
697 matches!(
698 typ.as_ref(),
699 TypeKind::Con(TypeConst {
700 builtin_id: Some(
701 BuiltinTypeId::U8
702 | BuiltinTypeId::U16
703 | BuiltinTypeId::U32
704 | BuiltinTypeId::U64
705 | BuiltinTypeId::I8
706 | BuiltinTypeId::I16
707 | BuiltinTypeId::I32
708 | BuiltinTypeId::I64
709 ),
710 ..
711 })
712 )
713}
714
715fn finalize_infer_for_public_api(
716 mut preds: Vec<Predicate>,
717 mut typ: Type,
718) -> Result<(Vec<Predicate>, Type), TypeError> {
719 let mut subst = Subst::new_sync();
720 for pred in &preds {
721 if pred.class.as_ref() == "Integral"
722 && let TypeKind::Var(tv) = pred.typ.as_ref()
723 {
724 subst = subst.insert(tv.id, Type::builtin(BuiltinTypeId::I32));
725 }
726 }
727
728 if !subst_is_empty(&subst) {
729 preds = dedup_preds(preds.apply(&subst));
730 typ = typ.apply(&subst);
731 }
732
733 for pred in &preds {
734 if pred.class.as_ref() != "Integral" {
735 continue;
736 }
737 if matches!(pred.typ.as_ref(), TypeKind::Var(_)) || is_integral_primitive(&pred.typ) {
738 continue;
739 }
740 return Err(TypeError::Unification("i32".into(), pred.typ.to_string()));
741 }
742
743 Ok((preds, typ))
744}
745
/// Errors produced by type inference, unification, and declaration checking.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum TypeError {
    /// Two types failed to unify; both are pre-rendered with `Display`.
    #[error("types do not unify: {0} vs {1}")]
    Unification(String, String),
    /// Binding the variable would make it contain itself (infinite type).
    #[error("occurs check failed for {0} in {1}")]
    Occurs(TypeVarId, String),
    #[error("unknown class {0}")]
    UnknownClass(Symbol),
    /// No instance of the class exists for the (rendered) type.
    #[error("no instance for {0} {1}")]
    NoInstance(Symbol, String),
    #[error("unknown type {0}")]
    UnknownTypeName(Symbol),
    /// A declaration tried to reuse the name of a builtin type.
    #[error("cannot redefine reserved builtin type `{0}`")]
    ReservedTypeName(Symbol),
    #[error("duplicate value definition `{0}`")]
    DuplicateValue(Symbol),
    #[error("duplicate class definition `{0}`")]
    DuplicateClass(Symbol),
    #[error("class `{class}` must have at least one type parameter (got {got})")]
    InvalidClassArity { class: Symbol, got: usize },
    #[error("duplicate class method `{0}`")]
    DuplicateClassMethod(Symbol),
    #[error("unknown method `{method}` in instance of class `{class}`")]
    UnknownInstanceMethod { class: Symbol, method: Symbol },
    #[error("missing implementation of `{method}` for instance of class `{class}`")]
    MissingInstanceMethod { class: Symbol, method: Symbol },
    #[error(
        "instance method `{method}` requires constraint {class} {typ}, but it is not in the instance context"
    )]
    MissingInstanceConstraint {
        method: Symbol,
        class: Symbol,
        typ: String,
    },
    #[error("unbound variable {0}")]
    UnknownVar(Symbol),
    #[error("ambiguous overload for {0}")]
    AmbiguousOverload(Symbol),
    /// Quantified variables that are constrained but never appear in the type
    /// (raised by `reject_ambiguous_scheme`).
    #[error("ambiguous type variable(s) {vars:?} in constraints: {constraints}")]
    AmbiguousTypeVars {
        vars: Vec<TypeVarId>,
        constraints: String,
    },
    #[error(
        "kind mismatch for class `{class}`: expected {expected} type argument(s) remaining, got {got} for {typ}"
    )]
    KindMismatch {
        class: Symbol,
        expected: usize,
        got: usize,
        typ: String,
    },
    /// Inferred non-ground constraints that the declaration does not entail
    /// (raised by `check_non_ground_predicates_declared`).
    #[error("missing type class constraint(s): {constraints}")]
    MissingConstraints { constraints: String },
    /// A construct the checker cannot type, e.g. `"empty record"` when unifying a
    /// record with `Dict`.
    #[error("unsupported expression {0}")]
    UnsupportedExpr(&'static str),
    #[error("unknown field `{field}` on {typ}")]
    UnknownField { field: Symbol, typ: String },
    #[error("field `{field}` is not definitely available on {typ}")]
    FieldNotKnown { field: Symbol, typ: String },
    #[error("non-exhaustive match for {typ}: missing {missing:?}")]
    NonExhaustiveMatch { typ: String, missing: Vec<Symbol> },
    /// Another error annotated with its source location; `with_span` never nests these.
    #[error("at {span}: {error}")]
    Spanned { span: Span, error: Box<TypeError> },
    /// Checker-internal failure, e.g. exceeding the maximum inference depth.
    #[error("internal error: {0}")]
    Internal(String),
    /// The gas meter ran out while inferring or unifying.
    #[error("{0}")]
    OutOfGas(#[from] OutOfGas),
}
815
816fn with_span(span: &Span, err: TypeError) -> TypeError {
817 match err {
818 TypeError::Spanned { .. } => err,
819 other => TypeError::Spanned {
820 span: *span,
821 error: Box::new(other),
822 },
823 }
824}
825
826fn format_constraints_referencing_vars(preds: &[Predicate], vars: &[TypeVarId]) -> String {
827 if vars.is_empty() {
828 return String::new();
829 }
830 let var_set: HashSet<TypeVarId> = vars.iter().copied().collect();
831 let mut parts = Vec::new();
832 for pred in preds {
833 let ftv = pred.ftv();
834 if ftv.iter().any(|v| var_set.contains(v)) {
835 parts.push(format!("{} {}", pred.class, pred.typ));
836 }
837 }
838 if parts.is_empty() {
839 for pred in preds {
841 parts.push(format!("{} {}", pred.class, pred.typ));
842 }
843 }
844 parts.join(", ")
845}
846
847fn reject_ambiguous_scheme(scheme: &Scheme) -> Result<(), TypeError> {
848 let quantified: HashSet<TypeVarId> = scheme.vars.iter().map(|v| v.id).collect();
852 if quantified.is_empty() {
853 return Ok(());
854 }
855
856 let typ_ftv = scheme.typ.ftv();
857 let mut vars = HashSet::new();
858 for pred in &scheme.preds {
859 let TypeKind::Var(tv) = pred.typ.as_ref() else {
860 continue;
861 };
862 if quantified.contains(&tv.id) && !typ_ftv.contains(&tv.id) {
863 vars.insert(tv.id);
864 }
865 }
866
867 if vars.is_empty() {
868 return Ok(());
869 }
870 let mut vars: Vec<TypeVarId> = vars.into_iter().collect();
871 vars.sort_unstable();
872 let constraints = format_constraints_referencing_vars(&scheme.preds, &vars);
873 Err(TypeError::AmbiguousTypeVars { vars, constraints })
874}
875
876fn scheme_compatible(existing: &Scheme, declared: &Scheme) -> bool {
877 let s = match unify(&existing.typ, &declared.typ) {
878 Ok(s) => s,
879 Err(_) => return false,
880 };
881
882 let existing_preds = existing.preds.apply(&s);
883 let declared_preds = declared.preds.apply(&s);
884
885 let mut lhs: Vec<(Symbol, String)> = existing_preds
886 .iter()
887 .map(|p| (p.class.clone(), p.typ.to_string()))
888 .collect();
889 let mut rhs: Vec<(Symbol, String)> = declared_preds
890 .iter()
891 .map(|p| (p.class.clone(), p.typ.to_string()))
892 .collect();
893 lhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
894 rhs.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
895 lhs == rhs
896}
897
/// Mutable, in-place unification state used during inference.
#[derive(Debug)]
struct Unifier<'g> {
    /// Binding table indexed by `TypeVarId`; `None` means the variable is unbound.
    subs: Vec<Option<Type>>,
    /// Optional gas meter charged per inference node and per unification step.
    gas: Option<&'g mut GasMeter>,
    /// Nesting limit enforced by `with_infer_depth`; `None` disables it.
    max_infer_depth: Option<usize>,
    /// Current nesting level tracked by `with_infer_depth`.
    infer_depth: usize,
}
911
/// Resource limits applied while type checking.
#[derive(Clone, Copy, Debug)]
pub struct TypeSystemLimits {
    pub max_infer_depth: Option<usize>,
}

impl TypeSystemLimits {
    /// No nesting limit on inference.
    pub fn unlimited() -> Self {
        TypeSystemLimits {
            max_infer_depth: None,
        }
    }

    /// Conservative preset: inference nesting capped at 4096 levels.
    pub fn safe_defaults() -> Self {
        TypeSystemLimits {
            max_infer_depth: Some(4096),
        }
    }
}

impl Default for TypeSystemLimits {
    /// Same as [`TypeSystemLimits::safe_defaults`].
    fn default() -> Self {
        TypeSystemLimits::safe_defaults()
    }
}
936
937fn superclass_closure(class_env: &ClassEnv, given: &[Predicate]) -> Vec<Predicate> {
938 let mut closure: Vec<Predicate> = given.to_vec();
939 let mut i = 0;
940 while i < closure.len() {
941 let p = closure[i].clone();
942 for sup in class_env.supers_of(&p.class) {
943 closure.push(Predicate::new(sup, p.typ.clone()));
944 }
945 i += 1;
946 }
947 closure
948}
949
950fn check_non_ground_predicates_declared(
951 class_env: &ClassEnv,
952 declared: &[Predicate],
953 inferred: &[Predicate],
954) -> Result<(), TypeError> {
955 let closure = superclass_closure(class_env, declared);
959 let closure_keys: HashSet<String> = closure
960 .iter()
961 .map(|p| format!("{} {}", p.class, p.typ))
962 .collect();
963 let mut missing = Vec::new();
964 for pred in inferred {
965 if pred.typ.ftv().is_empty() {
966 continue;
967 }
968 let key = format!("{} {}", pred.class, pred.typ);
969 if !closure_keys.contains(&key) {
970 missing.push(key);
971 }
972 }
973
974 missing.sort();
975 missing.dedup();
976 if missing.is_empty() {
977 return Ok(());
978 }
979 Err(TypeError::MissingConstraints {
980 constraints: missing.join(", "),
981 })
982}
983
984fn type_term_remaining_arity(ty: &Type) -> Option<usize> {
985 match ty.as_ref() {
986 TypeKind::Var(_) => None,
987 TypeKind::Con(tc) => Some(tc.arity),
988 TypeKind::App(l, _) => {
989 let a = type_term_remaining_arity(l)?;
990 Some(a.saturating_sub(1))
991 }
992 TypeKind::Fun(..) | TypeKind::Tuple(..) | TypeKind::Record(..) => Some(0),
993 }
994}
995
996fn max_head_app_arity_for_var(ty: &Type, var_id: TypeVarId) -> usize {
997 let mut max_arity = 0usize;
998 let mut stack: Vec<&Type> = vec![ty];
999 while let Some(t) = stack.pop() {
1000 match t.as_ref() {
1001 TypeKind::Var(_) | TypeKind::Con(_) => {}
1002 TypeKind::App(l, r) => {
1003 let mut head = t;
1005 let mut args = 0usize;
1006 while let TypeKind::App(left, _) = head.as_ref() {
1007 args += 1;
1008 head = left;
1009 }
1010 if let TypeKind::Var(tv) = head.as_ref()
1011 && tv.id == var_id
1012 {
1013 max_arity = max_arity.max(args);
1014 }
1015 stack.push(l);
1016 stack.push(r);
1017 }
1018 TypeKind::Fun(a, b) => {
1019 stack.push(a);
1020 stack.push(b);
1021 }
1022 TypeKind::Tuple(ts) => {
1023 for t in ts {
1024 stack.push(t);
1025 }
1026 }
1027 TypeKind::Record(fields) => {
1028 for (_, t) in fields {
1029 stack.push(t);
1030 }
1031 }
1032 }
1033 }
1034 max_arity
1035}
1036
1037impl<'g> Unifier<'g> {
1038 fn new(max_infer_depth: Option<usize>) -> Self {
1039 Self {
1040 subs: Vec::new(),
1041 gas: None,
1042 max_infer_depth,
1043 infer_depth: 0,
1044 }
1045 }
1046
1047 fn with_gas(gas: &'g mut GasMeter, max_infer_depth: Option<usize>) -> Self {
1048 Self {
1049 subs: Vec::new(),
1050 gas: Some(gas),
1051 max_infer_depth,
1052 infer_depth: 0,
1053 }
1054 }
1055
1056 fn with_infer_depth<T>(
1057 &mut self,
1058 span: Span,
1059 f: impl FnOnce(&mut Self) -> Result<T, TypeError>,
1060 ) -> Result<T, TypeError> {
1061 if let Some(max) = self.max_infer_depth
1062 && self.infer_depth >= max
1063 {
1064 return Err(TypeError::Spanned {
1065 span,
1066 error: Box::new(TypeError::Internal(format!(
1067 "maximum inference depth exceeded (max {max})"
1068 ))),
1069 });
1070 }
1071 self.infer_depth += 1;
1072 let res = f(self);
1073 self.infer_depth = self.infer_depth.saturating_sub(1);
1074 res
1075 }
1076
1077 fn charge_infer_node(&mut self) -> Result<(), TypeError> {
1078 let Some(gas) = self.gas.as_mut() else {
1079 return Ok(());
1080 };
1081 let cost = gas.costs.infer_node;
1082 gas.charge(cost)?;
1083 Ok(())
1084 }
1085
1086 fn charge_unify_step(&mut self) -> Result<(), TypeError> {
1087 let Some(gas) = self.gas.as_mut() else {
1088 return Ok(());
1089 };
1090 let cost = gas.costs.unify_step;
1091 gas.charge(cost)?;
1092 Ok(())
1093 }
1094
1095 fn bind_var(&mut self, id: TypeVarId, ty: Type) {
1096 if id >= self.subs.len() {
1097 self.subs.resize(id + 1, None);
1098 }
1099 self.subs[id] = Some(ty);
1100 }
1101
1102 fn prune(&mut self, ty: &Type) -> Type {
1103 match ty.as_ref() {
1104 TypeKind::Var(tv) => {
1105 let bound = self.subs.get(tv.id).and_then(|t| t.clone());
1106 match bound {
1107 Some(bound) => {
1108 let pruned = self.prune(&bound);
1109 self.bind_var(tv.id, pruned.clone());
1110 pruned
1111 }
1112 None => ty.clone(),
1113 }
1114 }
1115 TypeKind::Con(_) => ty.clone(),
1116 TypeKind::App(l, r) => {
1117 let l = self.prune(l);
1118 let r = self.prune(r);
1119 Type::app(l, r)
1120 }
1121 TypeKind::Fun(a, b) => {
1122 let a = self.prune(a);
1123 let b = self.prune(b);
1124 Type::fun(a, b)
1125 }
1126 TypeKind::Tuple(ts) => {
1127 Type::new(TypeKind::Tuple(ts.iter().map(|t| self.prune(t)).collect()))
1128 }
1129 TypeKind::Record(fields) => Type::new(TypeKind::Record(
1130 fields
1131 .iter()
1132 .map(|(name, ty)| (name.clone(), self.prune(ty)))
1133 .collect(),
1134 )),
1135 }
1136 }
1137
1138 fn apply_type(&mut self, ty: &Type) -> Type {
1139 self.prune(ty)
1140 }
1141
1142 fn occurs(&mut self, id: TypeVarId, ty: &Type) -> bool {
1143 match self.prune(ty).as_ref() {
1144 TypeKind::Var(tv) => tv.id == id,
1145 TypeKind::Con(_) => false,
1146 TypeKind::App(l, r) => self.occurs(id, l) || self.occurs(id, r),
1147 TypeKind::Fun(a, b) => self.occurs(id, a) || self.occurs(id, b),
1148 TypeKind::Tuple(ts) => ts.iter().any(|t| self.occurs(id, t)),
1149 TypeKind::Record(fields) => fields.iter().any(|(_, ty)| self.occurs(id, ty)),
1150 }
1151 }
1152
1153 fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
1154 self.charge_unify_step()?;
1155 let t1 = self.prune(t1);
1156 let t2 = self.prune(t2);
1157 match (t1.as_ref(), t2.as_ref()) {
1158 (TypeKind::Var(a), TypeKind::Var(b)) if a.id == b.id => Ok(()),
1159 (TypeKind::Var(tv), other) | (other, TypeKind::Var(tv)) => {
1160 if self.occurs(tv.id, &Type::new(other.clone())) {
1161 Err(TypeError::Occurs(
1162 tv.id,
1163 Type::new(other.clone()).to_string(),
1164 ))
1165 } else {
1166 self.bind_var(tv.id, Type::new(other.clone()));
1167 Ok(())
1168 }
1169 }
1170 (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(()),
1171 (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1172 self.unify(l1, l2)?;
1173 self.unify(r1, r2)
1174 }
1175 (TypeKind::Fun(a1, b1), TypeKind::Fun(a2, b2)) => {
1176 self.unify(a1, a2)?;
1177 self.unify(b1, b2)
1178 }
1179 (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1180 if ts1.len() != ts2.len() {
1181 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1182 }
1183 for (a, b) in ts1.iter().zip(ts2.iter()) {
1184 self.unify(a, b)?;
1185 }
1186 Ok(())
1187 }
1188 (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1189 if f1.len() != f2.len() {
1190 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1191 }
1192 for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1193 if n1 != n2 {
1194 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1195 }
1196 self.unify(t1, t2)?;
1197 }
1198 Ok(())
1199 }
1200 (TypeKind::Record(fields), TypeKind::App(head, arg))
1201 | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1202 TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1203 let elem_ty = record_elem_type_unifier(fields, self)?;
1204 self.unify(arg, &elem_ty)
1205 }
1206 TypeKind::Var(tv) => {
1207 self.unify(
1208 &Type::new(TypeKind::Var(tv.clone())),
1209 &Type::builtin(BuiltinTypeId::Dict),
1210 )?;
1211 let elem_ty = record_elem_type_unifier(fields, self)?;
1212 self.unify(arg, &elem_ty)
1213 }
1214 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1215 },
1216 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1217 }
1218 }
1219
1220 fn into_subst(mut self) -> Subst {
1221 let mut out = Subst::new_sync();
1222 for id in 0..self.subs.len() {
1223 if let Some(ty) = self.subs[id].clone() {
1224 let pruned = self.prune(&ty);
1225 out = out.insert(id, pruned);
1226 }
1227 }
1228 out
1229 }
1230}
1231
1232fn record_elem_type_unifier(
1233 fields: &[(Symbol, Type)],
1234 unifier: &mut Unifier<'_>,
1235) -> Result<Type, TypeError> {
1236 let mut iter = fields.iter();
1237 let first = match iter.next() {
1238 Some((_, ty)) => ty.clone(),
1239 None => return Err(TypeError::UnsupportedExpr("empty record")),
1240 };
1241 for (_, ty) in iter {
1242 unifier.unify(&first, ty)?;
1243 }
1244 Ok(unifier.apply_type(&first))
1245}
1246
1247fn bind(tv: &TypeVar, t: &Type) -> Result<Subst, TypeError> {
1248 if let TypeKind::Var(var) = t.as_ref()
1249 && var.id == tv.id
1250 {
1251 return Ok(Subst::new_sync());
1252 }
1253 if t.ftv().contains(&tv.id) {
1254 Err(TypeError::Occurs(tv.id, t.to_string()))
1255 } else {
1256 Ok(Subst::new_sync().insert(tv.id, t.clone()))
1257 }
1258}
1259
1260fn record_elem_type(fields: &[(Symbol, Type)]) -> Result<(Subst, Type), TypeError> {
1261 let mut iter = fields.iter();
1262 let first = match iter.next() {
1263 Some((_, ty)) => ty.clone(),
1264 None => return Err(TypeError::UnsupportedExpr("empty record")),
1265 };
1266 let mut subst = Subst::new_sync();
1267 let mut current = first;
1268 for (_, ty) in iter {
1269 let s_next = unify(¤t.apply(&subst), &ty.apply(&subst))?;
1270 subst = compose_subst(s_next, subst);
1271 current = current.apply(&subst);
1272 }
1273 Ok((subst.clone(), current.apply(&subst)))
1274}
1275
1276pub fn unify(t1: &Type, t2: &Type) -> Result<Subst, TypeError> {
1283 match (t1.as_ref(), t2.as_ref()) {
1284 (TypeKind::Fun(l1, r1), TypeKind::Fun(l2, r2)) => {
1285 let s1 = unify(l1, l2)?;
1286 let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1287 Ok(compose_subst(s2, s1))
1288 }
1289 (TypeKind::Record(f1), TypeKind::Record(f2)) => {
1290 if f1.len() != f2.len() {
1291 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1292 }
1293 let mut subst = Subst::new_sync();
1294 for ((n1, t1), (n2, t2)) in f1.iter().zip(f2.iter()) {
1295 if n1 != n2 {
1296 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1297 }
1298 let s_next = unify(&t1.apply(&subst), &t2.apply(&subst))?;
1299 subst = compose_subst(s_next, subst);
1300 }
1301 Ok(subst)
1302 }
1303 (TypeKind::Record(fields), TypeKind::App(head, arg))
1304 | (TypeKind::App(head, arg), TypeKind::Record(fields)) => match head.as_ref() {
1305 TypeKind::Con(c) if c.builtin_id == Some(BuiltinTypeId::Dict) => {
1306 let (s_fields, elem_ty) = record_elem_type(fields)?;
1307 let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1308 Ok(compose_subst(s_arg, s_fields))
1309 }
1310 TypeKind::Var(tv) => {
1311 let s_head = bind(tv, &Type::builtin(BuiltinTypeId::Dict))?;
1312 let arg = arg.apply(&s_head);
1313 let (s_fields, elem_ty) = record_elem_type(fields)?;
1314 let s_arg = unify(&arg.apply(&s_fields), &elem_ty)?;
1315 Ok(compose_subst(s_arg, compose_subst(s_fields, s_head)))
1316 }
1317 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1318 },
1319 (TypeKind::App(l1, r1), TypeKind::App(l2, r2)) => {
1320 let s1 = unify(l1, l2)?;
1321 let s2 = unify(&r1.apply(&s1), &r2.apply(&s1))?;
1322 Ok(compose_subst(s2, s1))
1323 }
1324 (TypeKind::Tuple(ts1), TypeKind::Tuple(ts2)) => {
1325 if ts1.len() != ts2.len() {
1326 return Err(TypeError::Unification(t1.to_string(), t2.to_string()));
1327 }
1328 let mut s = Subst::new_sync();
1329 for (a, b) in ts1.iter().zip(ts2.iter()) {
1330 let s_next = unify(&a.apply(&s), &b.apply(&s))?;
1331 s = compose_subst(s_next, s);
1332 }
1333 Ok(s)
1334 }
1335 (TypeKind::Var(tv), t) | (t, TypeKind::Var(tv)) => bind(tv, &Type::new(t.clone())),
1336 (TypeKind::Con(c1), TypeKind::Con(c2)) if c1 == c2 => Ok(Subst::new_sync()),
1337 _ => Err(TypeError::Unification(t1.to_string(), t2.to_string())),
1338 }
1339}
1340
1341#[derive(Default, Debug, Clone)]
1342pub struct TypeEnv {
1343 pub values: HashTrieMapSync<Symbol, Vec<Scheme>>,
1344}
1345
1346impl TypeEnv {
1347 pub fn new() -> Self {
1348 Self {
1349 values: HashTrieMapSync::new_sync(),
1350 }
1351 }
1352
1353 pub fn extend(&mut self, name: Symbol, scheme: Scheme) {
1354 self.values = self.values.insert(name, vec![scheme]);
1355 }
1356
1357 pub fn extend_overload(&mut self, name: Symbol, scheme: Scheme) {
1358 let mut schemes = self.values.get(&name).cloned().unwrap_or_default();
1359 schemes.push(scheme);
1360 self.values = self.values.insert(name, schemes);
1361 }
1362
1363 pub fn remove(&mut self, name: &Symbol) {
1364 self.values = self.values.remove(name);
1365 }
1366
1367 pub fn lookup(&self, name: &Symbol) -> Option<&[Scheme]> {
1368 self.values.get(name).map(|schemes| schemes.as_slice())
1369 }
1370}
1371
1372impl Types for TypeEnv {
1373 fn apply(&self, s: &Subst) -> Self {
1374 let mut values = HashTrieMapSync::new_sync();
1375 for (k, v) in self.values.iter() {
1376 let updated = v
1377 .iter()
1378 .map(|scheme| {
1379 if scheme.vars.is_empty() && !subst_is_empty(s) {
1382 scheme.apply(s)
1383 } else {
1384 scheme.clone()
1385 }
1386 })
1387 .collect();
1388 values = values.insert(k.clone(), updated);
1389 }
1390 TypeEnv { values }
1391 }
1392
1393 fn ftv(&self) -> HashSet<TypeVarId> {
1394 self.values
1395 .iter()
1396 .flat_map(|(_, schemes)| schemes.iter().flat_map(Types::ftv))
1397 .collect()
1398 }
1399}
1400
/// Names an ADT together with one of its variants.
#[derive(Clone, Debug)]
struct KnownVariant {
    /// Name of the algebraic data type.
    adt: Symbol,
    /// Name of the variant within that type.
    variant: Symbol,
}

/// Map from a symbol to the ADT variant it is known to be.
// NOTE(review): what the key symbol denotes (e.g. a bound variable vs. a
// constructor name) is not visible in this chunk; confirm at the use sites.
type KnownVariants = HashMap<Symbol, KnownVariant>;
1408
1409#[derive(Default, Debug, Clone)]
1410pub struct TypeVarSupply {
1411 counter: TypeVarId,
1412}
1413
1414impl TypeVarSupply {
1415 pub fn new() -> Self {
1416 Self { counter: 0 }
1417 }
1418
1419 pub fn fresh(&mut self, name_hint: impl Into<Option<Symbol>>) -> TypeVar {
1420 let tv = TypeVar::new(self.counter, name_hint.into());
1421 self.counter += 1;
1422 tv
1423 }
1424}
1425
1426fn apply_scheme_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> Scheme {
1427 let preds = scheme
1428 .preds
1429 .iter()
1430 .map(|pred| Predicate::new(pred.class.clone(), unifier.apply_type(&pred.typ)))
1431 .collect();
1432 let typ = unifier.apply_type(&scheme.typ);
1433 Scheme::new(scheme.vars.clone(), preds, typ)
1434}
1435
1436fn scheme_ftv_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> HashSet<TypeVarId> {
1437 let mut ftv = unifier.apply_type(&scheme.typ).ftv();
1438 for pred in &scheme.preds {
1439 ftv.extend(unifier.apply_type(&pred.typ).ftv());
1440 }
1441 for var in &scheme.vars {
1442 ftv.remove(&var.id);
1443 }
1444 ftv
1445}
1446
1447fn env_ftv_with_unifier(env: &TypeEnv, unifier: &mut Unifier<'_>) -> HashSet<TypeVarId> {
1448 let mut out = HashSet::new();
1449 for (_name, schemes) in env.values.iter() {
1450 for scheme in schemes {
1451 out.extend(scheme_ftv_with_unifier(scheme, unifier));
1452 }
1453 }
1454 out
1455}
1456
1457fn generalize_with_unifier(
1458 env: &TypeEnv,
1459 preds: Vec<Predicate>,
1460 typ: Type,
1461 unifier: &mut Unifier<'_>,
1462) -> Scheme {
1463 let preds: Vec<Predicate> = preds
1467 .into_iter()
1468 .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
1469 .collect();
1470 let typ = unifier.apply_type(&typ);
1471 let mut vars: Vec<TypeVar> = typ
1472 .ftv()
1473 .union(&preds.ftv())
1474 .copied()
1475 .collect::<HashSet<_>>()
1476 .difference(&env_ftv_with_unifier(env, unifier))
1477 .cloned()
1478 .map(|id| TypeVar::new(id, None))
1479 .collect();
1480 vars.sort_by_key(|v| v.id);
1481 Scheme::new(vars, preds, typ)
1482}
1483
1484fn monomorphic_scheme_with_unifier(
1485 preds: Vec<Predicate>,
1486 typ: Type,
1487 unifier: &mut Unifier<'_>,
1488) -> Scheme {
1489 let preds = dedup_preds(
1490 preds
1491 .into_iter()
1492 .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
1493 .collect(),
1494 );
1495 let typ = unifier.apply_type(&typ);
1496 Scheme::new(vec![], preds, typ)
1497}
1498
1499fn is_integral_literal_expr(expr: &Expr) -> bool {
1500 matches!(expr, Expr::Int(..) | Expr::Uint(..))
1501}
1502
1503pub fn generalize(env: &TypeEnv, preds: Vec<Predicate>, typ: Type) -> Scheme {
1506 let mut vars: Vec<TypeVar> = typ
1507 .ftv()
1508 .union(&preds.ftv())
1509 .copied()
1510 .collect::<HashSet<_>>()
1511 .difference(&env.ftv())
1512 .cloned()
1513 .map(|id| TypeVar::new(id, None))
1514 .collect();
1515 vars.sort_by_key(|v| v.id);
1516 Scheme::new(vars, preds, typ)
1517}
1518
1519pub fn instantiate(scheme: &Scheme, supply: &mut TypeVarSupply) -> (Vec<Predicate>, Type) {
1520 let mut subst = Subst::new_sync();
1523 for v in &scheme.vars {
1524 subst = subst.insert(v.id, Type::var(supply.fresh(v.name.clone())));
1525 }
1526 (scheme.preds.apply(&subst), scheme.typ.apply(&subst))
1527}
1528
/// A type parameter of an ADT: its source name and the type variable that
/// stands for it in the variants' argument types.
#[derive(Clone, Debug)]
pub struct AdtParam {
    pub name: Symbol,
    pub var: TypeVar,
}

/// One ADT constructor: its name and its argument types, in order.
#[derive(Clone, Debug)]
pub struct AdtVariant {
    pub name: Symbol,
    pub args: Vec<Type>,
}

/// An algebraic data type declaration: name, type parameters, and variants.
#[derive(Clone, Debug)]
pub struct AdtDecl {
    pub name: Symbol,
    pub params: Vec<AdtParam>,
    pub variants: Vec<AdtVariant>,
}
1554
1555impl AdtDecl {
1556 pub fn new(name: &Symbol, param_names: &[Symbol], supply: &mut TypeVarSupply) -> Self {
1557 let params = param_names
1558 .iter()
1559 .map(|p| AdtParam {
1560 name: p.clone(),
1561 var: supply.fresh(Some(p.clone())),
1562 })
1563 .collect();
1564 Self {
1565 name: name.clone(),
1566 params,
1567 variants: Vec::new(),
1568 }
1569 }
1570
1571 pub fn param_type(&self, name: &Symbol) -> Option<Type> {
1572 self.params
1573 .iter()
1574 .find(|p| &p.name == name)
1575 .map(|p| Type::var(p.var.clone()))
1576 }
1577
1578 pub fn add_variant(&mut self, name: Symbol, args: Vec<Type>) {
1579 self.variants.push(AdtVariant { name, args });
1580 }
1581
1582 pub fn result_type(&self) -> Type {
1583 let mut ty = Type::con(&self.name, self.params.len());
1584 for param in &self.params {
1585 ty = Type::app(ty, Type::var(param.var.clone()));
1586 }
1587 ty
1588 }
1589
1590 pub fn constructor_schemes(&self) -> Vec<(Symbol, Scheme)> {
1593 let result_ty = self.result_type();
1594 let vars: Vec<TypeVar> = self.params.iter().map(|p| p.var.clone()).collect();
1595 let mut out = Vec::new();
1596 for variant in &self.variants {
1597 let mut typ = result_ty.clone();
1598 for arg in variant.args.iter().rev() {
1599 typ = Type::fun(arg.clone(), typ);
1600 }
1601 out.push((variant.name.clone(), Scheme::new(vars.clone(), vec![], typ)));
1602 }
1603 out
1604 }
1605}
1606
/// A registered type class, described by its direct superclasses.
#[derive(Clone, Debug)]
pub struct Class {
    /// Names of the direct superclasses.
    pub supers: Vec<Symbol>,
}

impl Class {
    /// Creates a class entry with the given direct superclasses.
    pub fn new(supers: Vec<Symbol>) -> Self {
        Self { supers }
    }
}
1617
/// A class instance: the head predicate it provides and the context
/// predicates that must hold for it to apply.
#[derive(Clone, Debug)]
pub struct Instance {
    /// Predicates required by the instance (e.g. `C a` in `instance D (T a) <= C a`).
    pub context: Vec<Predicate>,
    /// The predicate this instance satisfies.
    pub head: Predicate,
}

impl Instance {
    /// Creates an instance from its context and head.
    pub fn new(context: Vec<Predicate>, head: Predicate) -> Self {
        Self { context, head }
    }
}
1629
1630#[derive(Default, Debug, Clone)]
1631pub struct ClassEnv {
1632 pub classes: HashMap<Symbol, Class>,
1633 pub instances: HashMap<Symbol, Vec<Instance>>,
1634}
1635
1636impl ClassEnv {
1637 pub fn new() -> Self {
1638 Self {
1639 classes: HashMap::new(),
1640 instances: HashMap::new(),
1641 }
1642 }
1643
1644 pub fn add_class(&mut self, name: Symbol, supers: Vec<Symbol>) {
1645 self.classes.insert(name, Class::new(supers));
1646 }
1647
1648 pub fn add_instance(&mut self, class: Symbol, inst: Instance) {
1649 self.instances.entry(class).or_default().push(inst);
1650 }
1651
1652 pub fn supers_of(&self, class: &Symbol) -> Vec<Symbol> {
1653 self.classes
1654 .get(class)
1655 .map(|c| c.supers.clone())
1656 .unwrap_or_default()
1657 }
1658}
1659
1660pub fn entails(
1661 class_env: &ClassEnv,
1662 given: &[Predicate],
1663 pred: &Predicate,
1664) -> Result<bool, TypeError> {
1665 let mut closure: Vec<Predicate> = given.to_vec();
1667 let mut i = 0;
1668 while i < closure.len() {
1669 let p = closure[i].clone();
1670 for sup in class_env.supers_of(&p.class) {
1671 closure.push(Predicate::new(sup, p.typ.clone()));
1672 }
1673 i += 1;
1674 }
1675
1676 if closure
1677 .iter()
1678 .any(|p| p.class == pred.class && p.typ == pred.typ)
1679 {
1680 return Ok(true);
1681 }
1682
1683 if !class_env.classes.contains_key(&pred.class) {
1684 return Err(TypeError::UnknownClass(pred.class.clone()));
1685 }
1686
1687 if let Some(instances) = class_env.instances.get(&pred.class) {
1688 for inst in instances {
1689 if let Ok(s) = unify(&inst.head.typ, &pred.typ) {
1690 let ctx = inst.context.apply(&s);
1691 if ctx
1692 .iter()
1693 .all(|c| entails(class_env, &closure, c).unwrap_or(false))
1694 {
1695 return Ok(true);
1696 }
1697 }
1698 }
1699 }
1700 Ok(false)
1701}
1702
/// Global typing state: value environment, class/instance environment, ADT
/// declarations, class metadata, and the fresh type-variable supply.
#[derive(Default, Debug, Clone)]
pub struct TypeSystem {
    /// Value names and their schemes.
    pub env: TypeEnv,
    /// Classes (with superclasses) and registered instances.
    pub classes: ClassEnv,
    /// Algebraic data types, keyed by type name.
    pub adts: HashMap<Symbol, AdtDecl>,
    /// Metadata for classes declared through `inject_class_decl`.
    pub class_info: HashMap<Symbol, ClassInfo>,
    /// Class method name -> declaring class and method scheme.
    pub class_methods: HashMap<Symbol, ClassMethodInfo>,
    /// Names bound in `env` only by a forward declaration
    /// (`inject_declare_fn_decl`); a later definition may replace them.
    pub declared_values: HashSet<Symbol>,
    /// Supply of fresh type variables.
    pub supply: TypeVarSupply,
    /// Resource limits consulted during inference (e.g. `max_infer_depth`).
    limits: TypeSystemLimits,
}
1718
/// Metadata recorded for a class declared with `inject_class_decl`.
#[derive(Clone, Debug)]
pub struct ClassInfo {
    pub name: Symbol,
    /// Class parameter names in declaration order (`inject_class_decl`
    /// rejects an empty list).
    pub params: Vec<Symbol>,
    /// Direct superclass names.
    pub supers: Vec<Symbol>,
    /// Method schemes keyed by method name.
    pub methods: BTreeMap<Symbol, Scheme>,
}

/// The class that declares a method, together with the method's scheme
/// (constrained by that class's predicate).
#[derive(Clone, Debug)]
pub struct ClassMethodInfo {
    pub class: Symbol,
    pub scheme: Scheme,
}

/// An instance declaration whose head and context have been resolved to
/// types, plus its source span for error reporting.
#[derive(Clone, Debug)]
pub struct PreparedInstanceDecl {
    pub span: Span,
    pub class: Symbol,
    pub head: Type,
    pub context: Vec<Predicate>,
}
1749
1750impl TypeSystem {
1751 pub fn new() -> Self {
1752 Self {
1753 env: TypeEnv::new(),
1754 classes: ClassEnv::new(),
1755 adts: HashMap::new(),
1756 class_info: HashMap::new(),
1757 class_methods: HashMap::new(),
1758 declared_values: HashSet::new(),
1759 supply: TypeVarSupply::new(),
1760 limits: TypeSystemLimits::default(),
1761 }
1762 }
1763
1764 pub fn fresh_type_var(&mut self, name: Option<Symbol>) -> TypeVar {
1765 self.supply.fresh(name)
1766 }
1767
    /// Replaces the inference limits used by subsequent type checking.
    pub fn set_limits(&mut self, limits: TypeSystemLimits) {
        self.limits = limits;
    }
1771
1772 pub fn with_prelude() -> Result<Self, TypeError> {
1773 let mut ts = TypeSystem::new();
1774 prelude::build_prelude(&mut ts)?;
1775 Ok(ts)
1776 }
1777
1778 pub fn inject_decl(&mut self, decl: &Decl) -> Result<(), TypeError> {
1779 match decl {
1780 Decl::Type(ty) => self.inject_type_decl(ty),
1781 Decl::Class(class_decl) => self.inject_class_decl(class_decl),
1782 Decl::Instance(inst_decl) => {
1783 let _ = self.inject_instance_decl(inst_decl)?;
1784 Ok(())
1785 }
1786 Decl::Fn(fd) => self.inject_fn_decls(std::slice::from_ref(fd)),
1787 Decl::DeclareFn(fd) => self.inject_declare_fn_decl(fd),
1788 Decl::Import(..) => Ok(()),
1789 }
1790 }
1791
1792 pub fn inject_decls(&mut self, decls: &[Decl]) -> Result<(), TypeError> {
1793 let mut pending_fns: Vec<FnDecl> = Vec::new();
1794 for decl in decls {
1795 if let Decl::Fn(fd) = decl {
1796 pending_fns.push(fd.clone());
1797 continue;
1798 }
1799
1800 if !pending_fns.is_empty() {
1801 self.inject_fn_decls(&pending_fns)?;
1802 pending_fns.clear();
1803 }
1804
1805 self.inject_decl(decl)?;
1806 }
1807 if !pending_fns.is_empty() {
1808 self.inject_fn_decls(&pending_fns)?;
1809 }
1810 Ok(())
1811 }
1812
1813 pub fn add_value(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1814 let name = sym(name.as_ref());
1815 self.declared_values.remove(&name);
1816 self.env.extend(name, scheme);
1817 }
1818
1819 pub fn add_overload(&mut self, name: impl AsRef<str>, scheme: Scheme) {
1820 let name = sym(name.as_ref());
1821 self.declared_values.remove(&name);
1822 self.env.extend_overload(name, scheme);
1823 }
1824
1825 pub fn inject_class(&mut self, name: impl AsRef<str>, supers: Vec<Symbol>) {
1826 self.classes.add_class(sym(name.as_ref()), supers);
1827 }
1828
1829 pub fn inject_instance(&mut self, class: impl AsRef<str>, inst: Instance) {
1830 self.classes.add_instance(sym(class.as_ref()), inst);
1831 }
1832
1833 pub fn inject_class_decl(&mut self, decl: &ClassDecl) -> Result<(), TypeError> {
1834 let span = decl.span;
1835 (|| {
1836 if self.class_info.contains_key(&decl.name)
1840 || self.classes.classes.contains_key(&decl.name)
1841 {
1842 return Err(TypeError::DuplicateClass(decl.name.clone()));
1843 }
1844 if decl.params.is_empty() {
1845 return Err(TypeError::InvalidClassArity {
1846 class: decl.name.clone(),
1847 got: decl.params.len(),
1848 });
1849 }
1850 let params = decl.params.clone();
1851
1852 let mut supers = Vec::with_capacity(decl.supers.len());
1858 if !decl.supers.is_empty() && params.len() != 1 {
1859 return Err(TypeError::UnsupportedExpr(
1860 "multi-parameter classes cannot declare superclasses yet",
1861 ));
1862 }
1863 for sup in &decl.supers {
1864 let mut vars = HashMap::new();
1865 let param = params[0].clone();
1866 let param_tv = self.supply.fresh(Some(param.clone()));
1867 vars.insert(param, param_tv.clone());
1868 let sup_ty = type_from_annotation_expr_vars(
1869 &self.adts,
1870 &sup.typ,
1871 &mut vars,
1872 &mut self.supply,
1873 )?;
1874 if sup_ty != Type::var(param_tv) {
1875 return Err(TypeError::UnsupportedExpr(
1876 "superclass constraints must be of the form `<= C a`",
1877 ));
1878 }
1879 supers.push(sup.class.to_dotted_symbol());
1880 }
1881
1882 self.classes.add_class(decl.name.clone(), supers.clone());
1883
1884 let mut methods = BTreeMap::new();
1885 for ClassMethodSig { name, typ } in &decl.methods {
1886 if self.env.lookup(name).is_some() || self.class_methods.contains_key(name) {
1887 return Err(TypeError::DuplicateClassMethod(name.clone()));
1888 }
1889
1890 let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1891 let mut param_tvs: Vec<TypeVar> = Vec::with_capacity(params.len());
1892 for param in ¶ms {
1893 let tv = self.supply.fresh(Some(param.clone()));
1894 vars.insert(param.clone(), tv.clone());
1895 param_tvs.push(tv);
1896 }
1897
1898 let ty =
1899 type_from_annotation_expr_vars(&self.adts, typ, &mut vars, &mut self.supply)?;
1900
1901 let mut scheme_vars: Vec<TypeVar> = vars.values().cloned().collect();
1902 scheme_vars.sort_by_key(|tv| tv.id);
1903 scheme_vars.dedup_by_key(|tv| tv.id);
1904
1905 let class_pred = Predicate {
1906 class: decl.name.clone(),
1907 typ: if param_tvs.len() == 1 {
1908 Type::var(param_tvs[0].clone())
1909 } else {
1910 Type::tuple(param_tvs.into_iter().map(Type::var).collect())
1911 },
1912 };
1913 let scheme = Scheme::new(scheme_vars, vec![class_pred], ty);
1914
1915 self.env.extend(name.clone(), scheme.clone());
1916 self.class_methods.insert(
1917 name.clone(),
1918 ClassMethodInfo {
1919 class: decl.name.clone(),
1920 scheme: scheme.clone(),
1921 },
1922 );
1923 methods.insert(name.clone(), scheme);
1924 }
1925
1926 self.class_info.insert(
1927 decl.name.clone(),
1928 ClassInfo {
1929 name: decl.name.clone(),
1930 params,
1931 supers,
1932 methods,
1933 },
1934 );
1935 Ok(())
1936 })()
1937 .map_err(|err| with_span(&span, err))
1938 }
1939
1940 pub fn inject_instance_decl(
1941 &mut self,
1942 decl: &InstanceDecl,
1943 ) -> Result<PreparedInstanceDecl, TypeError> {
1944 let span = decl.span;
1945 (|| {
1946 let class = decl.class.clone();
1947 if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
1948 return Err(TypeError::UnknownClass(class));
1949 }
1950
1951 let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
1952 let head = type_from_annotation_expr_vars(
1953 &self.adts,
1954 &decl.head,
1955 &mut vars,
1956 &mut self.supply,
1957 )?;
1958 let context = predicates_from_constraints(
1959 &self.adts,
1960 &decl.context,
1961 &mut vars,
1962 &mut self.supply,
1963 )?;
1964
1965 let inst = Instance::new(
1966 context.clone(),
1967 Predicate {
1968 class: decl.class.clone(),
1969 typ: head.clone(),
1970 },
1971 );
1972
1973 if let Some(info) = self.class_info.get(&decl.class) {
1975 for method in &decl.methods {
1976 if !info.methods.contains_key(&method.name) {
1977 return Err(TypeError::UnknownInstanceMethod {
1978 class: decl.class.clone(),
1979 method: method.name.clone(),
1980 });
1981 }
1982 }
1983 for method_name in info.methods.keys() {
1984 if !decl.methods.iter().any(|m| &m.name == method_name) {
1985 return Err(TypeError::MissingInstanceMethod {
1986 class: decl.class.clone(),
1987 method: method_name.clone(),
1988 });
1989 }
1990 }
1991 }
1992
1993 self.classes.add_instance(decl.class.clone(), inst);
1994 Ok(PreparedInstanceDecl {
1995 span,
1996 class: decl.class.clone(),
1997 head,
1998 context,
1999 })
2000 })()
2001 .map_err(|err| with_span(&span, err))
2002 }
2003
2004 pub fn prepare_instance_decl(
2005 &mut self,
2006 decl: &InstanceDecl,
2007 ) -> Result<PreparedInstanceDecl, TypeError> {
2008 let span = decl.span;
2009 (|| {
2010 let class = decl.class.clone();
2011 if !self.class_info.contains_key(&class) && !self.classes.classes.contains_key(&class) {
2012 return Err(TypeError::UnknownClass(class));
2013 }
2014
2015 let mut vars: HashMap<Symbol, TypeVar> = HashMap::new();
2016 let head = type_from_annotation_expr_vars(
2017 &self.adts,
2018 &decl.head,
2019 &mut vars,
2020 &mut self.supply,
2021 )?;
2022 let context = predicates_from_constraints(
2023 &self.adts,
2024 &decl.context,
2025 &mut vars,
2026 &mut self.supply,
2027 )?;
2028
2029 if let Some(info) = self.class_info.get(&decl.class) {
2031 for method in &decl.methods {
2032 if !info.methods.contains_key(&method.name) {
2033 return Err(TypeError::UnknownInstanceMethod {
2034 class: decl.class.clone(),
2035 method: method.name.clone(),
2036 });
2037 }
2038 }
2039 for method_name in info.methods.keys() {
2040 if !decl.methods.iter().any(|m| &m.name == method_name) {
2041 return Err(TypeError::MissingInstanceMethod {
2042 class: decl.class.clone(),
2043 method: method_name.clone(),
2044 });
2045 }
2046 }
2047 }
2048
2049 Ok(PreparedInstanceDecl {
2050 span,
2051 class: decl.class.clone(),
2052 head,
2053 context,
2054 })
2055 })()
2056 .map_err(|err| with_span(&span, err))
2057 }
2058
2059 pub fn inject_fn_decl(&mut self, decl: &FnDecl) -> Result<(), TypeError> {
2060 self.inject_fn_decls(std::slice::from_ref(decl))
2061 }
2062
2063 pub fn inject_fn_decls(&mut self, decls: &[FnDecl]) -> Result<(), TypeError> {
2064 if decls.is_empty() {
2065 return Ok(());
2066 }
2067
2068 let saved_env = self.env.clone();
2069 let saved_declared = self.declared_values.clone();
2070
2071 let result: Result<(), TypeError> = (|| {
2072 #[derive(Clone)]
2073 struct FnInfo {
2074 decl: FnDecl,
2075 expected: Type,
2076 declared_preds: Vec<Predicate>,
2077 scheme: Scheme,
2078 ann_vars: HashMap<Symbol, TypeVar>,
2079 }
2080
2081 let mut infos: Vec<FnInfo> = Vec::with_capacity(decls.len());
2082 let mut seen_names = HashSet::new();
2083
2084 for decl in decls {
2085 let span = decl.span;
2086 let info = (|| {
2087 let name = &decl.name.name;
2088 if !seen_names.insert(name.clone()) {
2089 return Err(TypeError::DuplicateValue(name.clone()));
2090 }
2091
2092 if self.env.lookup(name).is_some() {
2093 if self.declared_values.remove(name) {
2094 self.env.remove(name);
2096 } else {
2097 return Err(TypeError::DuplicateValue(name.clone()));
2098 }
2099 }
2100
2101 let mut sig = decl.ret.clone();
2102 for (_, ann) in decl.params.iter().rev() {
2103 let span = Span::from_begin_end(ann.span().begin, sig.span().end);
2104 sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
2105 }
2106
2107 let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
2108 let expected = type_from_annotation_expr_vars(
2109 &self.adts,
2110 &sig,
2111 &mut ann_vars,
2112 &mut self.supply,
2113 )?;
2114 let declared_preds = predicates_from_constraints(
2115 &self.adts,
2116 &decl.constraints,
2117 &mut ann_vars,
2118 &mut self.supply,
2119 )?;
2120
2121 let var_arities: HashMap<TypeVarId, usize> = ann_vars
2123 .values()
2124 .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
2125 .collect();
2126 for pred in &declared_preds {
2127 let _ = entails(&self.classes, &[], pred)?;
2128 let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
2129 else {
2130 continue;
2131 };
2132 let args: Vec<Type> = if expected_arities.len() == 1 {
2133 vec![pred.typ.clone()]
2134 } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2135 if parts.len() != expected_arities.len() {
2136 continue;
2137 }
2138 parts.clone()
2139 } else {
2140 continue;
2141 };
2142
2143 for (arg, expected_arity) in
2144 args.iter().zip(expected_arities.iter().copied())
2145 {
2146 let got =
2147 type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
2148 TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
2149 _ => None,
2150 });
2151 let Some(got) = got else {
2152 continue;
2153 };
2154 if got != expected_arity {
2155 return Err(TypeError::KindMismatch {
2156 class: pred.class.clone(),
2157 expected: expected_arity,
2158 got,
2159 typ: arg.to_string(),
2160 });
2161 }
2162 }
2163 }
2164
2165 let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
2166 vars.sort_by_key(|v| v.id);
2167 let scheme = Scheme::new(vars, declared_preds.clone(), expected.clone());
2168 reject_ambiguous_scheme(&scheme)?;
2169
2170 Ok(FnInfo {
2171 decl: decl.clone(),
2172 expected,
2173 declared_preds,
2174 scheme,
2175 ann_vars,
2176 })
2177 })();
2178
2179 infos.push(info.map_err(|err| with_span(&span, err))?);
2180 }
2181
2182 for info in &infos {
2185 self.env
2186 .extend(info.decl.name.name.clone(), info.scheme.clone());
2187 }
2188
2189 for info in infos {
2190 let span = info.decl.span;
2191 let mut lam_body = info.decl.body.clone();
2192 let mut lam_end = lam_body.span().end;
2193 for (param, ann) in info.decl.params.iter().rev() {
2194 let lam_constraints = Vec::new();
2195 let span = Span::from_begin_end(param.span.begin, lam_end);
2196 lam_body = Arc::new(Expr::Lam(
2197 span,
2198 Scope::new_sync(),
2199 param.clone(),
2200 Some(ann.clone()),
2201 lam_constraints,
2202 lam_body,
2203 ));
2204 lam_end = lam_body.span().end;
2205 }
2206
2207 let (typed, preds, inferred) = self.infer_typed(lam_body.as_ref())?;
2208 let s = unify(&inferred, &info.expected)?;
2209 let preds = preds.apply(&s);
2210 let inferred = inferred.apply(&s);
2211 let declared_preds = info.declared_preds.apply(&s);
2212 let expected = info.expected.apply(&s);
2213
2214 let var_arities: HashMap<TypeVarId, usize> = info
2216 .ann_vars
2217 .values()
2218 .map(|tv| (tv.id, max_head_app_arity_for_var(&expected, tv.id)))
2219 .collect();
2220 for pred in &declared_preds {
2221 let _ = entails(&self.classes, &[], pred)?;
2222 let Some(expected_arities) = self.expected_class_param_arities(&pred.class)
2223 else {
2224 continue;
2225 };
2226 let args: Vec<Type> = if expected_arities.len() == 1 {
2227 vec![pred.typ.clone()]
2228 } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2229 if parts.len() != expected_arities.len() {
2230 continue;
2231 }
2232 parts.clone()
2233 } else {
2234 continue;
2235 };
2236
2237 for (arg, expected_arity) in args.iter().zip(expected_arities.iter().copied()) {
2238 let got = type_term_remaining_arity(arg).or_else(|| match arg.as_ref() {
2239 TypeKind::Var(tv) => var_arities.get(&tv.id).copied(),
2240 _ => None,
2241 });
2242 let Some(got) = got else {
2243 continue;
2244 };
2245 if got != expected_arity {
2246 return Err(with_span(
2247 &span,
2248 TypeError::KindMismatch {
2249 class: pred.class.clone(),
2250 expected: expected_arity,
2251 got,
2252 typ: arg.to_string(),
2253 },
2254 ));
2255 }
2256 }
2257 }
2258
2259 check_non_ground_predicates_declared(&self.classes, &declared_preds, &preds)
2260 .map_err(|err| with_span(&span, err))?;
2261
2262 let _ = inferred;
2263 let _ = typed;
2264 }
2265
2266 Ok(())
2267 })();
2268
2269 if result.is_err() {
2270 self.env = saved_env;
2271 self.declared_values = saved_declared;
2272 }
2273 result
2274 }
2275
2276 pub fn inject_declare_fn_decl(&mut self, decl: &DeclareFnDecl) -> Result<(), TypeError> {
2277 let span = decl.span;
2278 (|| {
2279 let mut sig = decl.ret.clone();
2281 for (_, ann) in decl.params.iter().rev() {
2282 let span = Span::from_begin_end(ann.span().begin, sig.span().end);
2283 sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
2284 }
2285
2286 let mut ann_vars: HashMap<Symbol, TypeVar> = HashMap::new();
2287 let expected =
2288 type_from_annotation_expr_vars(&self.adts, &sig, &mut ann_vars, &mut self.supply)?;
2289 let declared_preds = predicates_from_constraints(
2290 &self.adts,
2291 &decl.constraints,
2292 &mut ann_vars,
2293 &mut self.supply,
2294 )?;
2295
2296 let mut vars: Vec<TypeVar> = ann_vars.values().cloned().collect();
2297 vars.sort_by_key(|v| v.id);
2298 let scheme = Scheme::new(vars, declared_preds, expected);
2299 reject_ambiguous_scheme(&scheme)?;
2300
2301 for pred in &scheme.preds {
2303 let _ = entails(&self.classes, &[], pred)?;
2304 }
2305
2306 let name = &decl.name.name;
2307
2308 if self.env.lookup(name).is_some() && !self.declared_values.contains(name) {
2311 return Ok(());
2312 }
2313
2314 if let Some(existing) = self.env.lookup(name) {
2315 if existing.iter().any(|s| scheme_compatible(s, &scheme)) {
2316 return Ok(());
2317 }
2318 return Err(TypeError::DuplicateValue(decl.name.name.clone()));
2319 }
2320
2321 self.env.extend(decl.name.name.clone(), scheme);
2322 self.declared_values.insert(decl.name.name.clone());
2323 Ok(())
2324 })()
2325 .map_err(|err| with_span(&span, err))
2326 }
2327
2328 pub fn instantiate_class_method_for_head(
2329 &mut self,
2330 class: &Symbol,
2331 method: &Symbol,
2332 head: &Type,
2333 ) -> Result<Type, TypeError> {
2334 let info = self
2335 .class_info
2336 .get(class)
2337 .ok_or_else(|| TypeError::UnknownClass(class.clone()))?;
2338 let scheme = info
2339 .methods
2340 .get(method)
2341 .ok_or_else(|| TypeError::UnknownInstanceMethod {
2342 class: class.clone(),
2343 method: method.clone(),
2344 })?;
2345
2346 let (preds, typ) = instantiate(scheme, &mut self.supply);
2347 let class_pred =
2348 preds
2349 .iter()
2350 .find(|p| &p.class == class)
2351 .ok_or(TypeError::UnsupportedExpr(
2352 "class method scheme missing class predicate",
2353 ))?;
2354 let s = unify(&class_pred.typ, head)?;
2355 Ok(typ.apply(&s))
2356 }
2357
2358 pub fn typecheck_instance_method(
2359 &mut self,
2360 prepared: &PreparedInstanceDecl,
2361 method: &InstanceMethodImpl,
2362 ) -> Result<TypedExpr, TypeError> {
2363 let expected =
2364 self.instantiate_class_method_for_head(&prepared.class, &method.name, &prepared.head)?;
2365 let (typed, preds, actual) = self.infer_typed(method.body.as_ref())?;
2366 let s = unify(&actual, &expected)?;
2367 let typed = typed.apply(&s);
2368 let preds = preds.apply(&s);
2369
2370 let mut given = prepared.context.clone();
2376
2377 given.push(Predicate::new(
2380 prepared.class.clone(),
2381 prepared.head.clone(),
2382 ));
2383 let mut i = 0;
2384 while i < given.len() {
2385 let p = given[i].clone();
2386 for sup in self.classes.supers_of(&p.class) {
2387 given.push(Predicate::new(sup, p.typ.clone()));
2388 }
2389 i += 1;
2390 }
2391
2392 for pred in &preds {
2393 if pred.typ.ftv().is_empty() {
2394 if !entails(&self.classes, &given, pred)? {
2395 return Err(TypeError::NoInstance(
2396 pred.class.clone(),
2397 pred.typ.to_string(),
2398 ));
2399 }
2400 } else if !given
2401 .iter()
2402 .any(|p| p.class == pred.class && p.typ == pred.typ)
2403 {
2404 return Err(TypeError::MissingInstanceConstraint {
2405 method: method.name.clone(),
2406 class: pred.class.clone(),
2407 typ: pred.typ.to_string(),
2408 });
2409 }
2410 }
2411
2412 Ok(typed)
2413 }
2414
2415 pub fn inject_adt(&mut self, adt: &AdtDecl) {
2419 self.adts.insert(adt.name.clone(), adt.clone());
2420 for (name, scheme) in adt.constructor_schemes() {
2421 self.register_value_scheme(&name, scheme);
2422 }
2423 }
2424
2425 pub fn adt_from_decl(&mut self, decl: &TypeDecl) -> Result<AdtDecl, TypeError> {
2426 let mut adt = AdtDecl::new(&decl.name, &decl.params, &mut self.supply);
2427 let mut param_map: HashMap<Symbol, TypeVar> = HashMap::new();
2428 for param in &adt.params {
2429 param_map.insert(param.name.clone(), param.var.clone());
2430 }
2431
2432 for variant in &decl.variants {
2433 let mut args = Vec::new();
2434 for arg in &variant.args {
2435 let ty = self.type_from_expr(decl, ¶m_map, arg)?;
2436 args.push(ty);
2437 }
2438 adt.add_variant(variant.name.clone(), args);
2439 }
2440 Ok(adt)
2441 }
2442
2443 pub fn inject_type_decl(&mut self, decl: &TypeDecl) -> Result<(), TypeError> {
2444 if BuiltinTypeId::from_symbol(&decl.name).is_some() {
2445 return Err(TypeError::ReservedTypeName(decl.name.clone()));
2446 }
2447 let adt = self.adt_from_decl(decl)?;
2448 self.inject_adt(&adt);
2449 Ok(())
2450 }
2451
2452 fn type_from_expr(
2453 &mut self,
2454 decl: &TypeDecl,
2455 params: &HashMap<Symbol, TypeVar>,
2456 expr: &TypeExpr,
2457 ) -> Result<Type, TypeError> {
2458 let span = *expr.span();
2459 let res = (|| match expr {
2460 TypeExpr::Name(_, name) => {
2461 let name_sym = name.to_dotted_symbol();
2462 if let Some(tv) = params.get(&name_sym) {
2463 Ok(Type::var(tv.clone()))
2464 } else {
2465 let name = normalize_type_name(&name_sym);
2466 if let Some(arity) = self.type_arity(decl, &name) {
2467 Ok(Type::con(name, arity))
2468 } else {
2469 Err(TypeError::UnknownTypeName(name))
2470 }
2471 }
2472 }
2473 TypeExpr::App(_, fun, arg) => {
2474 let fty = self.type_from_expr(decl, params, fun)?;
2475 let aty = self.type_from_expr(decl, params, arg)?;
2476 Ok(type_app_with_result_syntax(fty, aty))
2477 }
2478 TypeExpr::Fun(_, arg, ret) => {
2479 let arg_ty = self.type_from_expr(decl, params, arg)?;
2480 let ret_ty = self.type_from_expr(decl, params, ret)?;
2481 Ok(Type::fun(arg_ty, ret_ty))
2482 }
2483 TypeExpr::Tuple(_, elems) => {
2484 let mut out = Vec::new();
2485 for elem in elems {
2486 out.push(self.type_from_expr(decl, params, elem)?);
2487 }
2488 Ok(Type::tuple(out))
2489 }
2490 TypeExpr::Record(_, fields) => {
2491 let mut out = Vec::new();
2492 for (name, ty) in fields {
2493 out.push((name.clone(), self.type_from_expr(decl, params, ty)?));
2494 }
2495 Ok(Type::record(out))
2496 }
2497 })();
2498 res.map_err(|err| with_span(&span, err))
2499 }
2500
2501 fn type_arity(&self, decl: &TypeDecl, name: &Symbol) -> Option<usize> {
2502 if &decl.name == name {
2503 return Some(decl.params.len());
2504 }
2505 if let Some(adt) = self.adts.get(name) {
2506 return Some(adt.params.len());
2507 }
2508 BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2509 }
2510
2511 fn register_value_scheme(&mut self, name: &Symbol, scheme: Scheme) {
2512 match self.env.lookup(name) {
2513 None => self.env.extend(name.clone(), scheme),
2514 Some(existing) => {
2515 if existing.iter().any(|s| unify(&s.typ, &scheme.typ).is_ok()) {
2516 return;
2517 }
2518 self.env.extend_overload(name.clone(), scheme);
2519 }
2520 }
2521 }
2522
2523 pub fn infer_typed(
2524 &mut self,
2525 expr: &Expr,
2526 ) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
2527 self.infer_typed_inner(expr)
2528 }
2529
2530 pub fn infer_typed_with_gas(
2531 &mut self,
2532 expr: &Expr,
2533 gas: &mut GasMeter,
2534 ) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
2535 let known = KnownVariants::new();
2536 let mut unifier = Unifier::with_gas(gas, self.limits.max_infer_depth);
2537 let (preds, t, typed) = infer_expr(
2538 &mut unifier,
2539 &mut self.supply,
2540 &self.env,
2541 &self.adts,
2542 &known,
2543 expr,
2544 )
2545 .map_err(|err| with_span(expr.span(), err))?;
2546 let subst = unifier.into_subst();
2547 let mut typed = typed.apply(&subst);
2548 let mut preds = dedup_preds(preds.apply(&subst));
2549 let mut t = t.apply(&subst);
2550 let improve = improve_indexable(&preds)?;
2551 if !subst_is_empty(&improve) {
2552 typed = typed.apply(&improve);
2553 preds = dedup_preds(preds.apply(&improve));
2554 t = t.apply(&improve);
2555 }
2556 self.check_predicate_kinds(&preds)?;
2557 Ok((typed, preds, t))
2558 }
2559
2560 fn infer_typed_inner(
2561 &mut self,
2562 expr: &Expr,
2563 ) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
2564 let known = KnownVariants::new();
2565 let mut unifier = Unifier::new(self.limits.max_infer_depth);
2566 let (preds, t, typed) = infer_expr(
2567 &mut unifier,
2568 &mut self.supply,
2569 &self.env,
2570 &self.adts,
2571 &known,
2572 expr,
2573 )
2574 .map_err(|err| with_span(expr.span(), err))?;
2575 let subst = unifier.into_subst();
2576 let mut typed = typed.apply(&subst);
2577 let mut preds = dedup_preds(preds.apply(&subst));
2578 let mut t = t.apply(&subst);
2579 let improve = improve_indexable(&preds)?;
2580 if !subst_is_empty(&improve) {
2581 typed = typed.apply(&improve);
2582 preds = dedup_preds(preds.apply(&improve));
2583 t = t.apply(&improve);
2584 }
2585 self.check_predicate_kinds(&preds)?;
2586 Ok((typed, preds, t))
2587 }
2588
2589 pub fn infer(&mut self, expr: &Expr) -> Result<(Vec<Predicate>, Type), TypeError> {
2590 self.infer_inner(expr)
2591 }
2592
2593 pub fn infer_with_gas(
2594 &mut self,
2595 expr: &Expr,
2596 gas: &mut GasMeter,
2597 ) -> Result<(Vec<Predicate>, Type), TypeError> {
2598 let known = KnownVariants::new();
2599 let mut unifier = Unifier::with_gas(gas, self.limits.max_infer_depth);
2600 let (preds, t) = infer_expr_type(
2601 &mut unifier,
2602 &mut self.supply,
2603 &self.env,
2604 &self.adts,
2605 &known,
2606 expr,
2607 )
2608 .map_err(|err| with_span(expr.span(), err))?;
2609 let subst = unifier.into_subst();
2610 let preds = dedup_preds(preds.apply(&subst));
2611 let t = t.apply(&subst);
2612 self.check_predicate_kinds(&preds)?;
2613 finalize_infer_for_public_api(preds, t)
2614 }
2615
2616 fn infer_inner(&mut self, expr: &Expr) -> Result<(Vec<Predicate>, Type), TypeError> {
2617 let known = KnownVariants::new();
2618 let mut unifier = Unifier::new(self.limits.max_infer_depth);
2619 let (preds, t) = infer_expr_type(
2620 &mut unifier,
2621 &mut self.supply,
2622 &self.env,
2623 &self.adts,
2624 &known,
2625 expr,
2626 )
2627 .map_err(|err| with_span(expr.span(), err))?;
2628 let subst = unifier.into_subst();
2629 let mut preds = dedup_preds(preds.apply(&subst));
2630 let mut t = t.apply(&subst);
2631 let improve = improve_indexable(&preds)?;
2632 if !subst_is_empty(&improve) {
2633 preds = dedup_preds(preds.apply(&improve));
2634 t = t.apply(&improve);
2635 }
2636 self.check_predicate_kinds(&preds)?;
2637 finalize_infer_for_public_api(preds, t)
2638 }
2639
2640 fn expected_class_param_arities(&self, class: &Symbol) -> Option<Vec<usize>> {
2641 let info = self.class_info.get(class)?;
2642 let mut out = vec![0usize; info.params.len()];
2643 for scheme in info.methods.values() {
2644 for (idx, param) in info.params.iter().enumerate() {
2645 let Some(tv) = scheme.vars.iter().find(|v| v.name.as_ref() == Some(param)) else {
2646 continue;
2647 };
2648 out[idx] = out[idx].max(max_head_app_arity_for_var(&scheme.typ, tv.id));
2649 }
2650 }
2651 Some(out)
2652 }
2653
2654 fn check_predicate_kind(&self, pred: &Predicate) -> Result<(), TypeError> {
2655 let Some(expected) = self.expected_class_param_arities(&pred.class) else {
2656 return Ok(());
2658 };
2659
2660 let args: Vec<Type> = if expected.len() == 1 {
2661 vec![pred.typ.clone()]
2662 } else if let TypeKind::Tuple(parts) = pred.typ.as_ref() {
2663 if parts.len() != expected.len() {
2664 return Ok(());
2665 }
2666 parts.clone()
2667 } else {
2668 return Ok(());
2669 };
2670
2671 for (arg, expected_arity) in args.iter().zip(expected.iter().copied()) {
2672 let Some(got) = type_term_remaining_arity(arg) else {
2673 continue;
2677 };
2678 if got != expected_arity {
2679 return Err(TypeError::KindMismatch {
2680 class: pred.class.clone(),
2681 expected: expected_arity,
2682 got,
2683 typ: arg.to_string(),
2684 });
2685 }
2686 }
2687 Ok(())
2688 }
2689
2690 fn check_predicate_kinds(&self, preds: &[Predicate]) -> Result<(), TypeError> {
2691 for pred in preds {
2692 self.check_predicate_kind(pred)?;
2693 }
2694 Ok(())
2695 }
2696}
2697
2698fn improve_indexable(preds: &[Predicate]) -> Result<Subst, TypeError> {
2699 let mut subst = Subst::new_sync();
2700 loop {
2701 let mut changed = false;
2702 for pred in preds {
2703 let pred = pred.apply(&subst);
2704 if pred.class.as_ref() != "Indexable" {
2705 continue;
2706 }
2707 let TypeKind::Tuple(parts) = pred.typ.as_ref() else {
2708 continue;
2709 };
2710 if parts.len() != 2 {
2711 continue;
2712 }
2713 let container = parts[0].clone();
2714 let elem = parts[1].clone();
2715 let s = indexable_elem_subst(&container, &elem)?;
2716 if !subst_is_empty(&s) {
2717 subst = compose_subst(s, subst);
2718 changed = true;
2719 }
2720 }
2721 if !changed {
2722 break;
2723 }
2724 }
2725 Ok(subst)
2726}
2727
2728fn indexable_elem_subst(container: &Type, elem: &Type) -> Result<Subst, TypeError> {
2729 match container.as_ref() {
2730 TypeKind::App(head, arg) => match head.as_ref() {
2731 TypeKind::Con(tc)
2732 if matches!(
2733 tc.builtin_id,
2734 Some(BuiltinTypeId::List | BuiltinTypeId::Array)
2735 ) =>
2736 {
2737 unify(elem, arg)
2738 }
2739 _ => Ok(Subst::new_sync()),
2740 },
2741 TypeKind::Tuple(elems) => {
2742 if elems.is_empty() {
2743 return Ok(Subst::new_sync());
2744 }
2745 let mut subst = Subst::new_sync();
2746 let mut cur = elems[0].clone();
2747 for ty in elems.iter().skip(1) {
2748 let s_next = unify(&cur.apply(&subst), &ty.apply(&subst))?;
2749 subst = compose_subst(s_next, subst);
2750 cur = cur.apply(&subst);
2751 }
2752 let elem = elem.apply(&subst);
2753 let s_elem = unify(&elem, &cur.apply(&subst))?;
2754 Ok(compose_subst(s_elem, subst))
2755 }
2756 _ => Ok(Subst::new_sync()),
2757 }
2758}
2759
2760fn type_from_annotation_expr(
2761 adts: &HashMap<Symbol, AdtDecl>,
2762 expr: &TypeExpr,
2763) -> Result<Type, TypeError> {
2764 let span = *expr.span();
2765 let res = (|| match expr {
2766 TypeExpr::Name(_, name) => {
2767 let name = normalize_type_name(&name.to_dotted_symbol());
2768 match annotation_type_arity(adts, &name) {
2769 Some(arity) => Ok(Type::con(name, arity)),
2770 None => Err(TypeError::UnknownTypeName(name)),
2771 }
2772 }
2773 TypeExpr::App(_, fun, arg) => {
2774 let fty = type_from_annotation_expr(adts, fun)?;
2775 let aty = type_from_annotation_expr(adts, arg)?;
2776 Ok(type_app_with_result_syntax(fty, aty))
2777 }
2778 TypeExpr::Fun(_, arg, ret) => {
2779 let arg_ty = type_from_annotation_expr(adts, arg)?;
2780 let ret_ty = type_from_annotation_expr(adts, ret)?;
2781 Ok(Type::fun(arg_ty, ret_ty))
2782 }
2783 TypeExpr::Tuple(_, elems) => {
2784 let mut out = Vec::new();
2785 for elem in elems {
2786 out.push(type_from_annotation_expr(adts, elem)?);
2787 }
2788 Ok(Type::tuple(out))
2789 }
2790 TypeExpr::Record(_, fields) => {
2791 let mut out = Vec::new();
2792 for (name, ty) in fields {
2793 out.push((name.clone(), type_from_annotation_expr(adts, ty)?));
2794 }
2795 Ok(Type::record(out))
2796 }
2797 })();
2798 res.map_err(|err| with_span(&span, err))
2799}
2800
2801fn type_from_annotation_expr_vars(
2802 adts: &HashMap<Symbol, AdtDecl>,
2803 expr: &TypeExpr,
2804 vars: &mut HashMap<Symbol, TypeVar>,
2805 supply: &mut TypeVarSupply,
2806) -> Result<Type, TypeError> {
2807 let span = *expr.span();
2808 let res = (|| match expr {
2809 TypeExpr::Name(_, name) => {
2810 let name = normalize_type_name(&name.to_dotted_symbol());
2811 if let Some(arity) = annotation_type_arity(adts, &name) {
2812 Ok(Type::con(name, arity))
2813 } else if let Some(tv) = vars.get(&name) {
2814 Ok(Type::var(tv.clone()))
2815 } else {
2816 let is_upper = name
2817 .chars()
2818 .next()
2819 .map(|c| c.is_uppercase())
2820 .unwrap_or(false);
2821 if is_upper {
2822 return Err(TypeError::UnknownTypeName(name));
2823 }
2824 let tv = supply.fresh(Some(name.clone()));
2825 vars.insert(name.clone(), tv.clone());
2826 Ok(Type::var(tv))
2827 }
2828 }
2829 TypeExpr::App(_, fun, arg) => {
2830 let fty = type_from_annotation_expr_vars(adts, fun, vars, supply)?;
2831 let aty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2832 Ok(type_app_with_result_syntax(fty, aty))
2833 }
2834 TypeExpr::Fun(_, arg, ret) => {
2835 let arg_ty = type_from_annotation_expr_vars(adts, arg, vars, supply)?;
2836 let ret_ty = type_from_annotation_expr_vars(adts, ret, vars, supply)?;
2837 Ok(Type::fun(arg_ty, ret_ty))
2838 }
2839 TypeExpr::Tuple(_, elems) => {
2840 let mut out = Vec::new();
2841 for elem in elems {
2842 out.push(type_from_annotation_expr_vars(adts, elem, vars, supply)?);
2843 }
2844 Ok(Type::tuple(out))
2845 }
2846 TypeExpr::Record(_, fields) => {
2847 let mut out = Vec::new();
2848 for (name, ty) in fields {
2849 out.push((
2850 name.clone(),
2851 type_from_annotation_expr_vars(adts, ty, vars, supply)?,
2852 ));
2853 }
2854 Ok(Type::record(out))
2855 }
2856 })();
2857 res.map_err(|err| with_span(&span, err))
2858}
2859
2860fn annotation_type_arity(adts: &HashMap<Symbol, AdtDecl>, name: &Symbol) -> Option<usize> {
2861 if let Some(adt) = adts.get(name) {
2862 return Some(adt.params.len());
2863 }
2864 BuiltinTypeId::from_symbol(name).map(BuiltinTypeId::arity)
2865}
2866
2867fn normalize_type_name(name: &Symbol) -> Symbol {
2868 if name.as_ref() == "str" {
2869 BuiltinTypeId::String.as_symbol()
2870 } else {
2871 name.clone()
2872 }
2873}
2874
2875fn type_app_with_result_syntax(fun: Type, arg: Type) -> Type {
2876 if let TypeKind::App(head, ok) = fun.as_ref()
2880 && matches!(
2881 head.as_ref(),
2882 TypeKind::Con(c)
2883 if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
2884 )
2885 {
2886 return Type::app(Type::app(head.clone(), arg), ok.clone());
2887 }
2888 Type::app(fun, arg)
2889}
2890
2891type LambdaChain<'a> = (
2892 Vec<(Symbol, Option<TypeExpr>)>,
2893 Vec<TypeConstraint>,
2894 &'a Expr,
2895);
2896
2897fn collect_lambda_chain<'a>(expr: &'a Expr) -> LambdaChain<'a> {
2898 let mut params = Vec::new();
2899 let mut constraints = Vec::new();
2900 let mut cur = expr;
2901 let mut seen_constraints = false;
2902 while let Expr::Lam(_, _scope, param, ann, lam_constraints, body) = cur {
2903 if !lam_constraints.is_empty() {
2904 if seen_constraints {
2905 break;
2906 }
2907 constraints = lam_constraints.clone();
2908 seen_constraints = true;
2909 }
2910 params.push((param.name.clone(), ann.clone()));
2911 cur = body.as_ref();
2912 }
2913 (params, constraints, cur)
2914}
2915
2916fn predicates_from_constraints(
2917 adts: &HashMap<Symbol, AdtDecl>,
2918 constraints: &[TypeConstraint],
2919 vars: &mut HashMap<Symbol, TypeVar>,
2920 supply: &mut TypeVarSupply,
2921) -> Result<Vec<Predicate>, TypeError> {
2922 let mut out = Vec::with_capacity(constraints.len());
2923 for constraint in constraints {
2924 let ty = type_from_annotation_expr_vars(adts, &constraint.typ, vars, supply)?;
2925 out.push(Predicate::new(constraint.class.as_ref(), ty));
2926 }
2927 Ok(out)
2928}
2929
2930fn collect_app_chain(expr: &Expr) -> (&Expr, Vec<&Expr>) {
2931 let mut args = Vec::new();
2932 let mut cur = expr;
2933 while let Expr::App(_, f, x) = cur {
2934 args.push(x.as_ref());
2935 cur = f.as_ref();
2936 }
2937 args.reverse();
2938 (cur, args)
2939}
2940
2941fn narrow_overload_candidates(candidates: &[Type], arg_ty: &Type) -> Vec<Type> {
2942 let mut out = Vec::new();
2943 for candidate in candidates {
2944 let Some((params, ret)) = decompose_fun(candidate, 1) else {
2945 continue;
2946 };
2947 let param = ¶ms[0];
2948 if let Ok(s) = unify(param, arg_ty) {
2949 out.push(ret.apply(&s));
2950 }
2951 }
2952 out
2953}
2954
2955fn unary_app_arg(typ: &Type, ctor_name: &str) -> Option<Type> {
2956 let TypeKind::App(head, arg) = typ.as_ref() else {
2957 return None;
2958 };
2959 let TypeKind::Con(tc) = head.as_ref() else {
2960 return None;
2961 };
2962 (tc.name.as_ref() == ctor_name && tc.arity == 1).then(|| arg.clone())
2963}
2964
2965fn infer_app_arg_type(
2966 unifier: &mut Unifier<'_>,
2967 supply: &mut TypeVarSupply,
2968 env: &TypeEnv,
2969 adts: &HashMap<Symbol, AdtDecl>,
2970 known: &KnownVariants,
2971 arg_hint: Option<Type>,
2972 arg: &Expr,
2973) -> Result<(Vec<Predicate>, Type), TypeError> {
2974 match (arg_hint, arg) {
2975 (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
2976 infer_record_update_type_with_hint(
2977 unifier,
2978 supply,
2979 env,
2980 adts,
2981 known,
2982 base.as_ref(),
2983 updates,
2984 &arg_hint,
2985 )
2986 }
2987 (Some(arg_hint), Expr::Dict(_, kvs))
2988 if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
2989 {
2990 let TypeKind::Record(fields) = arg_hint.as_ref() else {
2991 unreachable!("guarded by matches!")
2992 };
2993 let expected: HashMap<_, _> =
2994 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
2995 let mut seen = HashSet::new();
2996 let mut preds = Vec::new();
2997 for (k, v) in kvs {
2998 let expected_ty = expected
2999 .get(k)
3000 .ok_or_else(|| TypeError::UnknownField {
3001 field: k.clone(),
3002 typ: Type::record(fields.clone()).to_string(),
3003 })?
3004 .clone();
3005 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
3006 unifier.unify(&t1, &expected_ty)?;
3007 preds.extend(p1);
3008 seen.insert(k.clone());
3009 }
3010 for key in expected.keys() {
3011 if !seen.contains(key.as_ref()) {
3012 return Err(TypeError::UnknownField {
3013 field: key.clone(),
3014 typ: Type::record(fields.clone()).to_string(),
3015 });
3016 }
3017 }
3018 let record_ty = Type::record(
3019 fields
3020 .iter()
3021 .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
3022 .collect(),
3023 );
3024 Ok((preds, record_ty))
3025 }
3026 _ => infer_expr_type(unifier, supply, env, adts, known, arg),
3027 }
3028}
3029
3030fn infer_app_arg_typed(
3031 unifier: &mut Unifier<'_>,
3032 supply: &mut TypeVarSupply,
3033 env: &TypeEnv,
3034 adts: &HashMap<Symbol, AdtDecl>,
3035 known: &KnownVariants,
3036 arg_hint: Option<Type>,
3037 arg: &Expr,
3038) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
3039 match (arg_hint, arg) {
3040 (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
3041 infer_record_update_typed_with_hint(
3042 unifier,
3043 supply,
3044 env,
3045 adts,
3046 known,
3047 base.as_ref(),
3048 updates,
3049 &arg_hint,
3050 )
3051 }
3052 (Some(arg_hint), Expr::Dict(_, kvs))
3053 if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
3054 {
3055 let TypeKind::Record(fields) = arg_hint.as_ref() else {
3056 unreachable!("guarded by matches!")
3057 };
3058 let mut preds = Vec::new();
3059 let mut typed_kvs = BTreeMap::new();
3060 let expected: HashMap<_, _> =
3061 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
3062 for (k, v) in kvs {
3063 let expected_ty = expected
3064 .get(k)
3065 .ok_or_else(|| TypeError::UnknownField {
3066 field: k.clone(),
3067 typ: Type::record(fields.clone()).to_string(),
3068 })?
3069 .clone();
3070 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
3071 unifier.unify(&t1, &expected_ty)?;
3072 preds.extend(p1);
3073 typed_kvs.insert(k.clone(), typed_v);
3074 }
3075 for key in expected.keys() {
3076 if !typed_kvs.contains_key(key.as_ref()) {
3077 return Err(TypeError::UnknownField {
3078 field: key.clone(),
3079 typ: Type::record(fields.clone()).to_string(),
3080 });
3081 }
3082 }
3083 let record_ty = Type::record(
3084 fields
3085 .iter()
3086 .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
3087 .collect(),
3088 );
3089 let typed = TypedExpr::new(record_ty.clone(), TypedExprKind::Dict(typed_kvs));
3090 Ok((preds, record_ty, typed))
3091 }
3092 _ => infer_expr(unifier, supply, env, adts, known, arg),
3093 }
3094}
3095
3096#[allow(clippy::too_many_arguments)]
3097fn infer_record_update_type_with_hint(
3098 unifier: &mut Unifier<'_>,
3099 supply: &mut TypeVarSupply,
3100 env: &TypeEnv,
3101 adts: &HashMap<Symbol, AdtDecl>,
3102 known: &KnownVariants,
3103 base: &Expr,
3104 updates: &BTreeMap<Symbol, Arc<Expr>>,
3105 hint_ty: &Type,
3106) -> Result<(Vec<Predicate>, Type), TypeError> {
3107 let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
3108 unifier.unify(&t_base, hint_ty)?;
3109 let base_ty = unifier.apply_type(&t_base);
3110 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
3111 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
3112 let (result_ty, fields) = resolve_record_update(
3113 unifier,
3114 supply,
3115 adts,
3116 &base_ty,
3117 known_variant,
3118 &update_fields,
3119 )?;
3120 let expected: HashMap<_, _> = fields.into_iter().collect();
3121
3122 let mut preds = p_base;
3123 for (k, v) in updates {
3124 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
3125 field: k.clone(),
3126 typ: result_ty.to_string(),
3127 })?;
3128 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
3129 unifier.unify(&t1, expected_ty)?;
3130 preds.extend(p1);
3131 }
3132 Ok((preds, result_ty))
3133}
3134
3135#[allow(clippy::too_many_arguments)]
3136fn infer_record_update_typed_with_hint(
3137 unifier: &mut Unifier<'_>,
3138 supply: &mut TypeVarSupply,
3139 env: &TypeEnv,
3140 adts: &HashMap<Symbol, AdtDecl>,
3141 known: &KnownVariants,
3142 base: &Expr,
3143 updates: &BTreeMap<Symbol, Arc<Expr>>,
3144 hint_ty: &Type,
3145) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
3146 let (p_base, t_base, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
3147 unifier.unify(&t_base, hint_ty)?;
3148 let base_ty = unifier.apply_type(&t_base);
3149 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
3150 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
3151 let (result_ty, fields) = resolve_record_update(
3152 unifier,
3153 supply,
3154 adts,
3155 &base_ty,
3156 known_variant,
3157 &update_fields,
3158 )?;
3159 let expected: HashMap<_, _> = fields.into_iter().collect();
3160
3161 let mut preds = p_base;
3162 let mut typed_updates = BTreeMap::new();
3163 for (k, v) in updates {
3164 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
3165 field: k.clone(),
3166 typ: result_ty.to_string(),
3167 })?;
3168 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
3169 unifier.unify(&t1, expected_ty)?;
3170 preds.extend(p1);
3171 typed_updates.insert(k.clone(), typed_v);
3172 }
3173
3174 let typed = TypedExpr::new(
3175 result_ty.clone(),
3176 TypedExprKind::RecordUpdate {
3177 base: Box::new(typed_base),
3178 updates: typed_updates,
3179 },
3180 );
3181 Ok((preds, result_ty, typed))
3182}
3183
3184fn infer_expr_type(
3185 unifier: &mut Unifier<'_>,
3186 supply: &mut TypeVarSupply,
3187 env: &TypeEnv,
3188 adts: &HashMap<Symbol, AdtDecl>,
3189 known: &KnownVariants,
3190 expr: &Expr,
3191) -> Result<(Vec<Predicate>, Type), TypeError> {
3192 let span = *expr.span();
3193 let res = unifier.with_infer_depth(span, |unifier| {
3194 infer_expr_type_inner(unifier, supply, env, adts, known, expr)
3195 });
3196 res.map_err(|err| with_span(&span, err))
3197}
3198
3199fn infer_expr_type_inner(
3200 unifier: &mut Unifier<'_>,
3201 supply: &mut TypeVarSupply,
3202 env: &TypeEnv,
3203 adts: &HashMap<Symbol, AdtDecl>,
3204 known: &KnownVariants,
3205 expr: &Expr,
3206) -> Result<(Vec<Predicate>, Type), TypeError> {
3207 unifier.charge_infer_node()?;
3208 match expr {
3209 Expr::Bool(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Bool))),
3210 Expr::Uint(_, _) => {
3211 let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
3212 Ok((vec![Predicate::new("Integral", lit_ty.clone())], lit_ty))
3213 }
3214 Expr::Int(_, _) => {
3215 let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
3216 Ok((
3217 vec![
3218 Predicate::new("Integral", lit_ty.clone()),
3219 Predicate::new("AdditiveGroup", lit_ty.clone()),
3220 ],
3221 lit_ty,
3222 ))
3223 }
3224 Expr::Float(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::F32))),
3225 Expr::String(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::String))),
3226 Expr::Uuid(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Uuid))),
3227 Expr::DateTime(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::DateTime))),
3228 Expr::Hole(_) => {
3229 let t = Type::var(supply.fresh(Some(sym("hole"))));
3230 Ok((vec![], t))
3231 }
3232 Expr::Var(var) => {
3233 let schemes = env
3234 .lookup(&var.name)
3235 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
3236 if schemes.len() == 1 {
3237 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
3238 let (preds, t) = instantiate(&scheme, supply);
3239 Ok((preds, t))
3240 } else {
3241 for scheme in schemes {
3242 if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
3243 return Err(TypeError::AmbiguousOverload(var.name.clone()));
3244 }
3245 }
3246 let t = Type::var(supply.fresh(Some(var.name.clone())));
3247 Ok((vec![], t))
3248 }
3249 }
3250 Expr::Lam(..) => {
3251 let (params, constraints, body) = collect_lambda_chain(expr);
3252 let mut ann_vars = HashMap::new();
3253 let mut param_tys = Vec::with_capacity(params.len());
3254 for (name, ann) in ¶ms {
3255 let param_ty = match ann {
3256 Some(ann) => type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?,
3257 None => Type::var(supply.fresh(Some(name.clone()))),
3258 };
3259 param_tys.push((name.clone(), param_ty));
3260 }
3261
3262 let mut env1 = env.clone();
3263 let mut known_body = known.clone();
3264 for (name, param_ty) in ¶m_tys {
3265 env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
3266 known_body.remove(name);
3267 }
3268
3269 let (mut preds, body_ty) =
3270 infer_expr_type(unifier, supply, &env1, adts, &known_body, body)?;
3271 let constraint_preds =
3272 predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
3273 preds.extend(constraint_preds);
3274
3275 let mut fun_ty = unifier.apply_type(&body_ty);
3276 for (_, param_ty) in param_tys.iter().rev() {
3277 fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
3278 }
3279 Ok((preds, fun_ty))
3280 }
3281 Expr::App(..) => {
3282 let (head, args) = collect_app_chain(expr);
3283 let (mut preds, mut func_ty) =
3284 infer_expr_type(unifier, supply, env, adts, known, head)?;
3285 let mut overload_name = None;
3286 let mut overload_candidates = if let Expr::Var(var) = head {
3287 if let Some(schemes) = env.lookup(&var.name) {
3288 if schemes.len() <= 1 {
3289 None
3290 } else {
3291 let mut candidates = Vec::new();
3292 for scheme in schemes {
3293 if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
3294 return Err(TypeError::AmbiguousOverload(var.name.clone()));
3295 }
3296 let scheme = apply_scheme_with_unifier(scheme, unifier);
3297 let (p, typ) = instantiate(&scheme, supply);
3298 if !p.is_empty() {
3299 return Err(TypeError::AmbiguousOverload(var.name.clone()));
3300 }
3301 candidates.push(typ);
3302 }
3303 overload_name = Some(var.name.clone());
3304 Some(candidates)
3305 }
3306 } else {
3307 None
3308 }
3309 } else {
3310 None
3311 };
3312 for arg in args {
3313 let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
3314 TypeKind::Fun(arg, _) => Some(arg.clone()),
3315 _ => None,
3316 };
3317 let (p_arg, arg_ty) =
3318 infer_app_arg_type(unifier, supply, env, adts, known, arg_hint, arg)?;
3319 let arg_ty = unifier.apply_type(&arg_ty);
3320 if let Some(candidates) = overload_candidates.take() {
3321 let candidates = candidates
3322 .into_iter()
3323 .map(|t| unifier.apply_type(&t))
3324 .collect::<Vec<_>>();
3325 let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
3326 if narrowed.is_empty()
3327 && let Some(name) = &overload_name
3328 {
3329 return Err(TypeError::AmbiguousOverload(name.clone()));
3330 }
3331 overload_candidates = Some(narrowed);
3332 }
3333 let res_ty = match overload_candidates.as_ref() {
3334 Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
3335 _ => Type::var(supply.fresh(Some("r".into()))),
3336 };
3337 unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
3338 preds.extend(p_arg);
3339 func_ty = match overload_candidates.as_ref() {
3340 Some(candidates) if candidates.len() == 1 => unifier.apply_type(&candidates[0]),
3341 _ => unifier.apply_type(&res_ty),
3342 };
3343 }
3344 Ok((preds, func_ty))
3345 }
3346 Expr::Project(_, base, field) => {
3347 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, base)?;
3348 let base_ty = unifier.apply_type(&t1);
3349 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
3350 let field_ty =
3351 resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
3352 Ok((p1, field_ty))
3353 }
3354 Expr::RecordUpdate(_, base, updates) => {
3355 let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
3356 let base_ty = unifier.apply_type(&t_base);
3357 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
3358 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
3359 let (result_ty, fields) = resolve_record_update(
3360 unifier,
3361 supply,
3362 adts,
3363 &base_ty,
3364 known_variant,
3365 &update_fields,
3366 )?;
3367 let expected: HashMap<_, _> = fields.into_iter().collect();
3368
3369 let mut preds = p_base;
3370 for (k, v) in updates {
3371 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
3372 field: k.clone(),
3373 typ: result_ty.to_string(),
3374 })?;
3375 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
3376 unifier.unify(&t1, expected_ty)?;
3377 preds.extend(p1);
3378 }
3379 Ok((preds, result_ty))
3380 }
3381 Expr::Let(..) => {
3382 let mut bindings = Vec::new();
3383 let mut cur = expr;
3384 while let Expr::Let(_, v, ann, d, b) = cur {
3385 bindings.push((v.clone(), ann.clone(), d.clone()));
3386 cur = b.as_ref();
3387 }
3388
3389 let mut env_cur = env.clone();
3390 let mut known_cur = known.clone();
3391 for (v, ann, d) in bindings {
3392 let (p1, t1) = if let Some(ref ann_expr) = ann {
3393 let mut ann_vars = HashMap::new();
3394 let ann_ty =
3395 type_from_annotation_expr_vars(adts, ann_expr, &mut ann_vars, supply)?;
3396 match d.as_ref() {
3397 Expr::RecordUpdate(_, base, updates) => infer_record_update_type_with_hint(
3398 unifier,
3399 supply,
3400 &env_cur,
3401 adts,
3402 &known_cur,
3403 base.as_ref(),
3404 updates,
3405 &ann_ty,
3406 )?,
3407 _ => {
3408 let (p1, t1) =
3409 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?;
3410 unifier.unify(&t1, &ann_ty)?;
3411 (p1, t1)
3412 }
3413 }
3414 } else {
3415 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?
3416 };
3417 let def_ty = unifier.apply_type(&t1);
3418 let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
3419 monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
3420 } else {
3421 let scheme = generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
3422 reject_ambiguous_scheme(&scheme)?;
3423 scheme
3424 };
3425 env_cur.extend(v.name.clone(), scheme);
3426 if let Some(known_variant) =
3427 known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
3428 {
3429 known_cur.insert(
3430 v.name.clone(),
3431 KnownVariant {
3432 adt: known_variant.adt,
3433 variant: known_variant.variant,
3434 },
3435 );
3436 } else {
3437 known_cur.remove(&v.name);
3438 }
3439 }
3440
3441 let (p_body, t_body) =
3442 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, cur)?;
3443 Ok((p_body, t_body))
3444 }
3445 Expr::LetRec(_, bindings, body) => {
3446 let mut env_seed = env.clone();
3447 let mut known_seed = known.clone();
3448 let mut binding_tys = HashMap::new();
3449 for (var, _ann, _def) in bindings {
3450 let tv = Type::var(supply.fresh(Some(var.name.clone())));
3451 env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
3452 known_seed.remove(&var.name);
3453 binding_tys.insert(var.name.clone(), tv);
3454 }
3455
3456 let mut inferred = Vec::with_capacity(bindings.len());
3457 for (var, ann, def) in bindings {
3458 let (preds, def_ty) =
3459 infer_expr_type(unifier, supply, &env_seed, adts, &known_seed, def)?;
3460 if let Some(ann) = ann {
3461 let mut ann_vars = HashMap::new();
3462 let ann_ty = type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
3463 unifier.unify(&def_ty, &ann_ty)?;
3464 }
3465 let binding_ty = binding_tys
3466 .get(&var.name)
3467 .cloned()
3468 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
3469 unifier.unify(&binding_ty, &def_ty)?;
3470 let resolved_ty = unifier.apply_type(&binding_ty);
3471
3472 if let Some(known_variant) =
3473 known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
3474 {
3475 known_seed.insert(
3476 var.name.clone(),
3477 KnownVariant {
3478 adt: known_variant.adt,
3479 variant: known_variant.variant,
3480 },
3481 );
3482 } else {
3483 known_seed.remove(&var.name);
3484 }
3485 inferred.push((var.name.clone(), preds, resolved_ty));
3486 }
3487
3488 let mut env_body = env.clone();
3489 for (name, preds, def_ty) in inferred {
3490 let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
3491 reject_ambiguous_scheme(&scheme)?;
3492 env_body.extend(name, scheme);
3493 }
3494
3495 let (p_body, t_body) =
3496 infer_expr_type(unifier, supply, &env_body, adts, &known_seed, body)?;
3497 Ok((p_body, t_body))
3498 }
3499 Expr::Ite(_, cond, then_expr, else_expr) => {
3500 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, cond)?;
3501 unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
3502 let (p2, t2) = infer_expr_type(unifier, supply, env, adts, known, then_expr)?;
3503 let (p3, t3) = infer_expr_type(unifier, supply, env, adts, known, else_expr)?;
3504 unifier.unify(&t2, &t3)?;
3505 let out_ty = unifier.apply_type(&t2);
3506 let mut preds = p1;
3507 preds.extend(p2);
3508 preds.extend(p3);
3509 Ok((preds, out_ty))
3510 }
3511 Expr::Tuple(_, elems) => {
3512 let mut preds = Vec::new();
3513 let mut types = Vec::new();
3514 for elem in elems {
3515 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
3516 preds.extend(p1);
3517 types.push(unifier.apply_type(&t1));
3518 }
3519 let tuple_ty = Type::tuple(types);
3520 Ok((preds, tuple_ty))
3521 }
3522 Expr::List(_, elems) => {
3523 let elem_tv = Type::var(supply.fresh(Some("a".into())));
3524 let mut preds = Vec::new();
3525 for elem in elems {
3526 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
3527 unifier.unify(&t1, &elem_tv)?;
3528 preds.extend(p1);
3529 }
3530 let list_ty = Type::app(
3531 Type::builtin(BuiltinTypeId::List),
3532 unifier.apply_type(&elem_tv),
3533 );
3534 Ok((preds, list_ty))
3535 }
3536 Expr::Dict(_, kvs) => {
3537 let elem_tv = Type::var(supply.fresh(Some("v".into())));
3538 let mut preds = Vec::new();
3539 for v in kvs.values() {
3540 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
3541 unifier.unify(&t1, &elem_tv)?;
3542 preds.extend(p1);
3543 }
3544 let dict_ty = Type::app(
3545 Type::builtin(BuiltinTypeId::Dict),
3546 unifier.apply_type(&elem_tv),
3547 );
3548 Ok((preds, dict_ty))
3549 }
3550 Expr::Match(_, scrutinee, arms) => {
3551 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, scrutinee.as_ref())?;
3552 let mut preds = p1;
3553 let res_ty = Type::var(supply.fresh(Some("match".into())));
3554 let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
3555
3556 for (pat, expr) in arms {
3557 let scrutinee_ty = unifier.apply_type(&t1);
3558 let (p_pat, binds) = infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
3559 preds.extend(p_pat);
3560
3561 let mut env_arm = env.clone();
3562 for (name, ty) in binds {
3563 env_arm.extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
3564 }
3565 let mut known_arm = known.clone();
3566 if let Expr::Var(var) = scrutinee.as_ref() {
3567 match pat {
3568 Pattern::Named(_, name, _) => {
3569 let name_sym = name.to_dotted_symbol();
3570 if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
3571 known_arm.insert(
3572 var.name.clone(),
3573 KnownVariant {
3574 adt: adt.name.clone(),
3575 variant: name_sym,
3576 },
3577 );
3578 } else {
3579 known_arm.remove(&var.name);
3580 }
3581 }
3582 _ => {
3583 known_arm.remove(&var.name);
3584 }
3585 }
3586 }
3587 let (p_expr, t_expr) =
3588 infer_expr_type(unifier, supply, &env_arm, adts, &known_arm, expr)?;
3589 unifier.unify(&res_ty, &t_expr)?;
3590 preds.extend(p_expr);
3591 }
3592
3593 let scrutinee_ty = unifier.apply_type(&t1);
3594 check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
3595 let out_ty = unifier.apply_type(&res_ty);
3596 Ok((preds, out_ty))
3597 }
3598 Expr::Ann(_, expr, ann) => {
3599 let ann_ty = type_from_annotation_expr(adts, ann)?;
3600 match expr.as_ref() {
3601 Expr::RecordUpdate(_, base, updates) => {
3602 let (preds, out_ty) = infer_record_update_type_with_hint(
3603 unifier,
3604 supply,
3605 env,
3606 adts,
3607 known,
3608 base.as_ref(),
3609 updates,
3610 &ann_ty,
3611 )?;
3612 Ok((preds, out_ty))
3613 }
3614 _ => {
3615 let (preds, expr_ty) =
3616 infer_expr_type(unifier, supply, env, adts, known, expr)?;
3617 unifier.unify(&expr_ty, &ann_ty)?;
3618 let out_ty = unifier.apply_type(&ann_ty);
3619 Ok((preds, out_ty))
3620 }
3621 }
3622 }
3623 }
3624}
3625
3626fn infer_expr(
3627 unifier: &mut Unifier<'_>,
3628 supply: &mut TypeVarSupply,
3629 env: &TypeEnv,
3630 adts: &HashMap<Symbol, AdtDecl>,
3631 known: &KnownVariants,
3632 expr: &Expr,
3633) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
3634 let span = *expr.span();
3635 let res = unifier.with_infer_depth(span, |unifier| {
3636 (|| {
3637 unifier.charge_infer_node()?;
3638 match expr {
3639 Expr::Bool(_, v) => {
3640 let t = Type::builtin(BuiltinTypeId::Bool);
3641 Ok((
3642 vec![],
3643 t.clone(),
3644 TypedExpr::new(t, TypedExprKind::Bool(*v)),
3645 ))
3646 }
3647 Expr::Uint(_, v) => {
3648 let t = Type::var(supply.fresh(Some(sym("n"))));
3649 Ok((
3650 vec![Predicate::new("Integral", t.clone())],
3651 t.clone(),
3652 TypedExpr::new(t, TypedExprKind::Uint(*v)),
3653 ))
3654 }
3655 Expr::Int(_, v) => {
3656 let t = Type::var(supply.fresh(Some(sym("n"))));
3657 Ok((
3658 vec![
3659 Predicate::new("Integral", t.clone()),
3660 Predicate::new("AdditiveGroup", t.clone()),
3661 ],
3662 t.clone(),
3663 TypedExpr::new(t, TypedExprKind::Int(*v)),
3664 ))
3665 }
3666 Expr::Float(_, v) => {
3667 let t = Type::builtin(BuiltinTypeId::F32);
3668 Ok((
3669 vec![],
3670 t.clone(),
3671 TypedExpr::new(t, TypedExprKind::Float(*v)),
3672 ))
3673 }
3674 Expr::String(_, v) => {
3675 let t = Type::builtin(BuiltinTypeId::String);
3676 Ok((
3677 vec![],
3678 t.clone(),
3679 TypedExpr::new(t, TypedExprKind::String(v.clone())),
3680 ))
3681 }
3682 Expr::Uuid(_, v) => {
3683 let t = Type::builtin(BuiltinTypeId::Uuid);
3684 Ok((
3685 vec![],
3686 t.clone(),
3687 TypedExpr::new(t, TypedExprKind::Uuid(*v)),
3688 ))
3689 }
3690 Expr::DateTime(_, v) => {
3691 let t = Type::builtin(BuiltinTypeId::DateTime);
3692 Ok((
3693 vec![],
3694 t.clone(),
3695 TypedExpr::new(t, TypedExprKind::DateTime(*v)),
3696 ))
3697 }
3698 Expr::Hole(_) => {
3699 let t = Type::var(supply.fresh(Some(sym("hole"))));
3700 Ok((vec![], t.clone(), TypedExpr::new(t, TypedExprKind::Hole)))
3701 }
3702 Expr::Var(var) => {
3703 let schemes = env
3704 .lookup(&var.name)
3705 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
3706 if schemes.len() == 1 {
3707 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
3708 let (preds, t) = instantiate(&scheme, supply);
3709 let typed = TypedExpr::new(
3710 t.clone(),
3711 TypedExprKind::Var {
3712 name: var.name.clone(),
3713 overloads: vec![],
3714 },
3715 );
3716 Ok((preds, t, typed))
3717 } else {
3718 let mut overloads = Vec::new();
3719 for scheme in schemes {
3720 if !scheme.preds.is_empty() {
3727 return Err(TypeError::AmbiguousOverload(var.name.clone()));
3728 }
3729
3730 let scheme = apply_scheme_with_unifier(scheme, unifier);
3731 let (preds, typ) = instantiate(&scheme, supply);
3732 if !preds.is_empty() {
3733 return Err(TypeError::AmbiguousOverload(var.name.clone()));
3734 }
3735 overloads.push(typ);
3736 }
3737 let t = Type::var(supply.fresh(Some(var.name.clone())));
3738 let typed = TypedExpr::new(
3739 t.clone(),
3740 TypedExprKind::Var {
3741 name: var.name.clone(),
3742 overloads,
3743 },
3744 );
3745 Ok((vec![], t, typed))
3746 }
3747 }
3748 Expr::Lam(..) => {
3749 let (params, constraints, body) = collect_lambda_chain(expr);
3750 let mut ann_vars = HashMap::new();
3751 let mut param_tys = Vec::with_capacity(params.len());
3752 for (name, ann) in ¶ms {
3753 let param_ty = match ann {
3754 Some(ann) => {
3755 type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?
3756 }
3757 None => Type::var(supply.fresh(Some(name.clone()))),
3758 };
3759 param_tys.push((name.clone(), param_ty));
3760 }
3761
3762 let mut env1 = env.clone();
3763 let mut known_body = known.clone();
3764 for (name, param_ty) in ¶m_tys {
3765 env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
3766 known_body.remove(name);
3767 }
3768
3769 let (mut preds, body_ty, typed_body) =
3770 infer_expr(unifier, supply, &env1, adts, &known_body, body)?;
3771 let constraint_preds =
3772 predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
3773 preds.extend(constraint_preds);
3774
3775 let mut typed = typed_body;
3776 let mut fun_ty = unifier.apply_type(&body_ty);
3777 for (name, param_ty) in param_tys.iter().rev() {
3778 fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
3779 typed = TypedExpr::new(
3780 fun_ty.clone(),
3781 TypedExprKind::Lam {
3782 param: name.clone(),
3783 body: Box::new(typed),
3784 },
3785 );
3786 }
3787
3788 Ok((preds, fun_ty, typed))
3789 }
3790 Expr::App(..) => {
3791 let (head, args) = collect_app_chain(expr);
3792 let (mut preds, mut func_ty, mut typed) =
3793 infer_expr(unifier, supply, env, adts, known, head)?;
3794 let mut overload_name = None;
3795 let mut overload_candidates = match &typed.kind {
3796 TypedExprKind::Var { name, overloads } if !overloads.is_empty() => {
3797 overload_name = Some(name.clone());
3798 Some(overloads.clone())
3799 }
3800 _ => None,
3801 };
3802 for arg in args {
3803 let expected_arg = match unifier.apply_type(&func_ty).as_ref() {
3804 TypeKind::Fun(arg, _) => Some(arg.clone()),
3805 _ => None,
3806 };
3807 let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
3808 TypeKind::Fun(arg, _) => Some(arg.clone()),
3809 _ => None,
3810 };
3811 let (p_arg, arg_ty, typed_arg) =
3812 infer_app_arg_typed(unifier, supply, env, adts, known, arg_hint, arg)?;
3813 let mut arg_ty = unifier.apply_type(&arg_ty);
3814 let mut typed_arg = typed_arg;
3815
3816 if let Some(expected_arg) = expected_arg {
3819 let expected_arg = unifier.apply_type(&expected_arg);
3820 if let (Some(expected_elem), Some(arg_elem)) = (
3821 unary_app_arg(&expected_arg, "Array"),
3822 unary_app_arg(&arg_ty, "List"),
3823 ) {
3824 unifier.unify(&expected_elem, &arg_elem)?;
3825 let elem_ty = unifier.apply_type(&expected_elem);
3826 let list_ty = Type::list(elem_ty.clone());
3827 let array_ty = Type::array(elem_ty);
3828 let coercion_ty = Type::fun(list_ty, array_ty.clone());
3829 let coercion_fn = TypedExpr::new(
3830 coercion_ty,
3831 TypedExprKind::Var {
3832 name: sym("prim_array_from_list"),
3833 overloads: vec![],
3834 },
3835 );
3836 typed_arg = TypedExpr::new(
3837 array_ty.clone(),
3838 TypedExprKind::App(Box::new(coercion_fn), Box::new(typed_arg)),
3839 );
3840 arg_ty = array_ty;
3841 }
3842 }
3843 if let Some(candidates) = overload_candidates.take() {
3844 let candidates = candidates
3845 .into_iter()
3846 .map(|t| unifier.apply_type(&t))
3847 .collect::<Vec<_>>();
3848 let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
3849 if narrowed.is_empty()
3850 && let Some(name) = &overload_name
3851 {
3852 return Err(TypeError::AmbiguousOverload(name.clone()));
3853 }
3854 overload_candidates = Some(narrowed);
3855 }
3856 let res_ty = match overload_candidates.as_ref() {
3857 Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
3858 _ => Type::var(supply.fresh(Some("r".into()))),
3859 };
3860 unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
3861 let result_ty = match overload_candidates.as_ref() {
3862 Some(candidates) if candidates.len() == 1 => {
3863 unifier.apply_type(&candidates[0])
3864 }
3865 _ => unifier.apply_type(&res_ty),
3866 };
3867 preds.extend(p_arg);
3868 typed = TypedExpr::new(
3869 result_ty.clone(),
3870 TypedExprKind::App(Box::new(typed), Box::new(typed_arg)),
3871 );
3872 func_ty = result_ty;
3873 }
3874 Ok((preds, func_ty, typed))
3875 }
3876 Expr::Project(_, base, field) => {
3877 let (p1, t1, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
3878 let base_ty = unifier.apply_type(&t1);
3879 let known_variant =
3880 known_variant_from_expr_with_known(base, &base_ty, adts, known);
3881 let field_ty =
3882 resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
3883 let typed = TypedExpr::new(
3884 field_ty.clone(),
3885 TypedExprKind::Project {
3886 expr: Box::new(typed_base),
3887 field: field.clone(),
3888 },
3889 );
3890 Ok((p1, field_ty, typed))
3891 }
3892 Expr::RecordUpdate(_, base, updates) => {
3893 let (p_base, t_base, typed_base) =
3894 infer_expr(unifier, supply, env, adts, known, base)?;
3895 let base_ty = unifier.apply_type(&t_base);
3896 let known_variant =
3897 known_variant_from_expr_with_known(base, &base_ty, adts, known);
3898 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
3899 let (result_ty, fields) = resolve_record_update(
3900 unifier,
3901 supply,
3902 adts,
3903 &base_ty,
3904 known_variant,
3905 &update_fields,
3906 )?;
3907 let expected: HashMap<_, _> = fields.into_iter().collect();
3908
3909 let mut preds = p_base;
3910 let mut typed_updates = BTreeMap::new();
3911 for (k, v) in updates {
3912 let expected_ty =
3913 expected.get(k).ok_or_else(|| TypeError::UnknownField {
3914 field: k.clone(),
3915 typ: result_ty.to_string(),
3916 })?;
3917 let (p1, t1, typed_v) =
3918 infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
3919 unifier.unify(&t1, expected_ty)?;
3920 preds.extend(p1);
3921 typed_updates.insert(k.clone(), typed_v);
3922 }
3923 let typed = TypedExpr::new(
3924 result_ty.clone(),
3925 TypedExprKind::RecordUpdate {
3926 base: Box::new(typed_base),
3927 updates: typed_updates,
3928 },
3929 );
3930 Ok((preds, result_ty, typed))
3931 }
3932 Expr::Let(..) => {
3933 let mut bindings = Vec::new();
3934 let mut cur = expr;
3935 while let Expr::Let(_, v, ann, d, b) = cur {
3936 bindings.push((v.clone(), ann.clone(), d.clone()));
3937 cur = b.as_ref();
3938 }
3939
3940 let mut env_cur = env.clone();
3941 let mut known_cur = known.clone();
3942 let mut typed_defs = Vec::new();
3943 for (v, ann, d) in bindings {
3944 let (p1, t1, typed_def) = if let Some(ref ann_expr) = ann {
3945 let mut ann_vars = HashMap::new();
3946 let ann_ty = type_from_annotation_expr_vars(
3947 adts,
3948 ann_expr,
3949 &mut ann_vars,
3950 supply,
3951 )?;
3952 match d.as_ref() {
3953 Expr::RecordUpdate(_, base, updates) => {
3954 infer_record_update_typed_with_hint(
3955 unifier,
3956 supply,
3957 &env_cur,
3958 adts,
3959 &known_cur,
3960 base.as_ref(),
3961 updates,
3962 &ann_ty,
3963 )?
3964 }
3965 _ => {
3966 let (p1, t1, typed_def) = infer_expr(
3967 unifier, supply, &env_cur, adts, &known_cur, &d,
3968 )?;
3969 unifier.unify(&t1, &ann_ty)?;
3970 (p1, t1, typed_def)
3971 }
3972 }
3973 } else {
3974 infer_expr(unifier, supply, &env_cur, adts, &known_cur, &d)?
3975 };
3976 let def_ty = unifier.apply_type(&t1);
3977 let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
3978 monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
3979 } else {
3980 let scheme =
3981 generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
3982 reject_ambiguous_scheme(&scheme)?;
3983 scheme
3984 };
3985 env_cur.extend(v.name.clone(), scheme);
3986 if let Some(known_variant) =
3987 known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
3988 {
3989 known_cur.insert(
3990 v.name.clone(),
3991 KnownVariant {
3992 adt: known_variant.adt,
3993 variant: known_variant.variant,
3994 },
3995 );
3996 } else {
3997 known_cur.remove(&v.name);
3998 }
3999 typed_defs.push((v.name.clone(), typed_def));
4000 }
4001
4002 let (p_body, t_body, typed_body) =
4003 infer_expr(unifier, supply, &env_cur, adts, &known_cur, cur)?;
4004
4005 let mut typed = typed_body;
4006 for (name, def) in typed_defs.into_iter().rev() {
4007 typed = TypedExpr::new(
4008 t_body.clone(),
4009 TypedExprKind::Let {
4010 name,
4011 def: Box::new(def),
4012 body: Box::new(typed),
4013 },
4014 );
4015 }
4016 Ok((p_body, t_body, typed))
4017 }
4018 Expr::LetRec(_, bindings, body) => {
4019 let mut env_seed = env.clone();
4020 let mut known_seed = known.clone();
4021 let mut binding_tys = HashMap::new();
4022 for (var, _ann, _def) in bindings {
4023 let tv = Type::var(supply.fresh(Some(var.name.clone())));
4024 env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
4025 known_seed.remove(&var.name);
4026 binding_tys.insert(var.name.clone(), tv);
4027 }
4028
4029 let mut inferred_defs = Vec::with_capacity(bindings.len());
4030 for (var, ann, def) in bindings {
4031 let (preds, def_ty, typed_def) =
4032 infer_expr(unifier, supply, &env_seed, adts, &known_seed, def)?;
4033 if let Some(ann) = ann {
4034 let mut ann_vars = HashMap::new();
4035 let ann_ty =
4036 type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
4037 unifier.unify(&def_ty, &ann_ty)?;
4038 }
4039 let binding_ty = binding_tys
4040 .get(&var.name)
4041 .cloned()
4042 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
4043 unifier.unify(&binding_ty, &def_ty)?;
4044 let resolved_ty = unifier.apply_type(&binding_ty);
4045
4046 if let Some(known_variant) =
4047 known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
4048 {
4049 known_seed.insert(
4050 var.name.clone(),
4051 KnownVariant {
4052 adt: known_variant.adt,
4053 variant: known_variant.variant,
4054 },
4055 );
4056 } else {
4057 known_seed.remove(&var.name);
4058 }
4059 inferred_defs.push((var.name.clone(), preds, resolved_ty, typed_def));
4060 }
4061
4062 let mut env_body = env.clone();
4063 let mut typed_bindings = Vec::with_capacity(inferred_defs.len());
4064 for (name, preds, def_ty, typed_def) in inferred_defs {
4065 let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
4066 reject_ambiguous_scheme(&scheme)?;
4067 env_body.extend(name.clone(), scheme);
4068 typed_bindings.push((name, typed_def));
4069 }
4070
4071 let (p_body, t_body, typed_body) =
4072 infer_expr(unifier, supply, &env_body, adts, &known_seed, body)?;
4073 let typed = TypedExpr::new(
4074 t_body.clone(),
4075 TypedExprKind::LetRec {
4076 bindings: typed_bindings,
4077 body: Box::new(typed_body),
4078 },
4079 );
4080 Ok((p_body, t_body, typed))
4081 }
4082 Expr::Ite(_, cond, then_expr, else_expr) => {
4083 let (p1, t1, typed_cond) = infer_expr(unifier, supply, env, adts, known, cond)?;
4084 unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
4085 let (p2, t2, typed_then) =
4086 infer_expr(unifier, supply, env, adts, known, then_expr)?;
4087 let (p3, t3, typed_else) =
4088 infer_expr(unifier, supply, env, adts, known, else_expr)?;
4089 unifier.unify(&t2, &t3)?;
4090 let out_ty = unifier.apply_type(&t2);
4091 let mut preds = p1;
4092 preds.extend(p2);
4093 preds.extend(p3);
4094 let typed = TypedExpr::new(
4095 out_ty.clone(),
4096 TypedExprKind::Ite {
4097 cond: Box::new(typed_cond),
4098 then_expr: Box::new(typed_then),
4099 else_expr: Box::new(typed_else),
4100 },
4101 );
4102 Ok((preds, out_ty, typed))
4103 }
4104 Expr::Tuple(_, elems) => {
4105 let mut preds = Vec::new();
4106 let mut types = Vec::new();
4107 let mut typed_elems = Vec::new();
4108 for elem in elems {
4109 let (p1, t1, typed_elem) =
4110 infer_expr(unifier, supply, env, adts, known, elem)?;
4111 preds.extend(p1);
4112 types.push(unifier.apply_type(&t1));
4113 typed_elems.push(typed_elem);
4114 }
4115 let tuple_ty = Type::tuple(types);
4116 let typed = TypedExpr::new(tuple_ty.clone(), TypedExprKind::Tuple(typed_elems));
4117 Ok((preds, tuple_ty, typed))
4118 }
4119 Expr::List(_, elems) => {
4120 let elem_tv = Type::var(supply.fresh(Some("a".into())));
4121 let mut preds = Vec::new();
4122 let mut typed_elems = Vec::new();
4123 for elem in elems {
4124 let (p1, t1, typed_elem) =
4125 infer_expr(unifier, supply, env, adts, known, elem)?;
4126 unifier.unify(&t1, &elem_tv)?;
4127 preds.extend(p1);
4128 typed_elems.push(typed_elem);
4129 }
4130 let list_ty = Type::app(
4131 Type::builtin(BuiltinTypeId::List),
4132 unifier.apply_type(&elem_tv),
4133 );
4134 let typed = TypedExpr::new(list_ty.clone(), TypedExprKind::List(typed_elems));
4135 Ok((preds, list_ty, typed))
4136 }
4137 Expr::Dict(_, kvs) => {
4138 let elem_tv = Type::var(supply.fresh(Some("v".into())));
4139 let mut preds = Vec::new();
4140 let mut typed_kvs = BTreeMap::new();
4141 for (k, v) in kvs {
4142 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v)?;
4143 unifier.unify(&t1, &elem_tv)?;
4144 preds.extend(p1);
4145 typed_kvs.insert(k.clone(), typed_v);
4146 }
4147 let dict_ty = Type::app(
4148 Type::builtin(BuiltinTypeId::Dict),
4149 unifier.apply_type(&elem_tv),
4150 );
4151 let typed = TypedExpr::new(dict_ty.clone(), TypedExprKind::Dict(typed_kvs));
4152 Ok((preds, dict_ty, typed))
4153 }
4154 Expr::Match(_, scrutinee, arms) => {
4155 let (p1, t1, typed_scrutinee) =
4156 infer_expr(unifier, supply, env, adts, known, scrutinee)?;
4157 let mut preds = p1;
4158 let mut typed_arms = Vec::new();
4159 let res_ty = Type::var(supply.fresh(Some("match".into())));
4160 let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
4161
4162 for (pat, expr) in arms {
4163 let scrutinee_ty = unifier.apply_type(&t1);
4164 let (p_pat, binds) =
4165 infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
4166 preds.extend(p_pat);
4167
4168 let mut env_arm = env.clone();
4169 for (name, ty) in binds {
4170 env_arm
4171 .extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
4172 }
4173 let mut known_arm = known.clone();
4174 if let Expr::Var(var) = scrutinee.as_ref() {
4175 match pat {
4176 Pattern::Named(_, name, _) => {
4177 let name_sym = name.to_dotted_symbol();
4178 if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
4179 known_arm.insert(
4180 var.name.clone(),
4181 KnownVariant {
4182 adt: adt.name.clone(),
4183 variant: name_sym,
4184 },
4185 );
4186 } else {
4187 known_arm.remove(&var.name);
4188 }
4189 }
4190 _ => {
4191 known_arm.remove(&var.name);
4192 }
4193 }
4194 }
4195 let (p_expr, t_expr, typed_expr) =
4196 infer_expr(unifier, supply, &env_arm, adts, &known_arm, expr)?;
4197 unifier.unify(&res_ty, &t_expr)?;
4198 preds.extend(p_expr);
4199 typed_arms.push((pat.clone(), typed_expr));
4200 }
4201
4202 let scrutinee_ty = unifier.apply_type(&t1);
4203 check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
4204 let out_ty = unifier.apply_type(&res_ty);
4205 let typed = TypedExpr::new(
4206 out_ty.clone(),
4207 TypedExprKind::Match {
4208 scrutinee: Box::new(typed_scrutinee),
4209 arms: typed_arms,
4210 },
4211 );
4212 Ok((preds, out_ty, typed))
4213 }
4214 Expr::Ann(_, expr, ann) => {
4215 let ann_ty = type_from_annotation_expr(adts, ann)?;
4216 match expr.as_ref() {
4217 Expr::RecordUpdate(_, base, updates) => {
4218 infer_record_update_typed_with_hint(
4219 unifier,
4220 supply,
4221 env,
4222 adts,
4223 known,
4224 base.as_ref(),
4225 updates,
4226 &ann_ty,
4227 )
4228 }
4229 _ => {
4230 let (preds, expr_ty, typed_expr) =
4231 infer_expr(unifier, supply, env, adts, known, expr)?;
4232 unifier.unify(&expr_ty, &ann_ty)?;
4233 let out_ty = unifier.apply_type(&ann_ty);
4234 Ok((preds, out_ty, typed_expr))
4235 }
4236 }
4237 }
4238 }
4239 })()
4240 });
4241 res.map_err(|err| with_span(&span, err))
4242}
4243
4244fn ctor_lookup<'a>(
4245 adts: &'a HashMap<Symbol, AdtDecl>,
4246 name: &Symbol,
4247) -> Option<(&'a AdtDecl, &'a AdtVariant)> {
4248 let mut found = None;
4249 for adt in adts.values() {
4250 if let Some(variant) = adt.variants.iter().find(|v| &v.name == name) {
4251 if found.is_some() {
4252 return None;
4253 }
4254 found = Some((adt, variant));
4255 }
4256 }
4257 found
4258}
4259
4260fn record_fields(variant: &AdtVariant) -> Option<&[(Symbol, Type)]> {
4261 if variant.args.len() != 1 {
4262 return None;
4263 }
4264 match variant.args[0].as_ref() {
4265 TypeKind::Record(fields) => Some(fields),
4266 _ => None,
4267 }
4268}
4269
4270fn instantiate_variant_fields(
4271 adt: &AdtDecl,
4272 variant: &AdtVariant,
4273 supply: &mut TypeVarSupply,
4274) -> Option<(Type, Vec<(Symbol, Type)>)> {
4275 let fields = record_fields(variant)?;
4276 let mut subst = Subst::new_sync();
4277 for param in &adt.params {
4278 let fresh = Type::var(supply.fresh(param.var.name.clone()));
4279 subst = subst.insert(param.var.id, fresh);
4280 }
4281 let result_ty = adt.result_type().apply(&subst);
4282 let fields = fields
4283 .iter()
4284 .map(|(name, ty)| (name.clone(), ty.apply(&subst)))
4285 .collect();
4286 Some((result_ty, fields))
4287}
4288
4289fn known_variant_from_expr(
4290 expr: &Expr,
4291 expr_ty: &Type,
4292 adts: &HashMap<Symbol, AdtDecl>,
4293) -> Option<KnownVariant> {
4294 let mut expr = expr;
4295 while let Expr::Ann(_, inner, _) = expr {
4296 expr = inner.as_ref();
4297 }
4298 if matches!(expr_ty.as_ref(), TypeKind::Fun(..)) {
4299 return None;
4300 }
4301 let ctor = match expr {
4302 Expr::App(_, f, _) => match f.as_ref() {
4303 Expr::Var(var) => var.name.clone(),
4304 _ => return None,
4305 },
4306 _ => return None,
4307 };
4308 let (adt, variant) = ctor_lookup(adts, &ctor)?;
4309 record_fields(variant)?;
4310 Some(KnownVariant {
4311 adt: adt.name.clone(),
4312 variant: variant.name.clone(),
4313 })
4314}
4315
4316fn known_variant_from_expr_with_known(
4317 expr: &Expr,
4318 expr_ty: &Type,
4319 adts: &HashMap<Symbol, AdtDecl>,
4320 known: &KnownVariants,
4321) -> Option<KnownVariant> {
4322 let mut expr = expr;
4323 while let Expr::Ann(_, inner, _) = expr {
4324 expr = inner.as_ref();
4325 }
4326 match expr {
4327 Expr::Var(var) => known.get(&var.name).cloned(),
4328 Expr::RecordUpdate(_, base, _) => {
4329 known_variant_from_expr_with_known(base.as_ref(), expr_ty, adts, known)
4330 }
4331 _ => known_variant_from_expr(expr, expr_ty, adts),
4332 }
4333}
4334
4335fn select_record_variant<'a, F>(
4336 adts: &'a HashMap<Symbol, AdtDecl>,
4337 base_ty: &Type,
4338 known_variant: Option<KnownVariant>,
4339 field_for_errors: &Symbol,
4340 matches_fields: F,
4341) -> Result<(&'a AdtDecl, &'a AdtVariant), TypeError>
4342where
4343 F: Fn(&[(Symbol, Type)]) -> bool,
4344{
4345 if let Some(info) = known_variant {
4346 let adt = adts
4347 .get(&info.adt)
4348 .ok_or_else(|| TypeError::UnknownTypeName(info.adt.clone()))?;
4349 let variant = adt
4350 .variants
4351 .iter()
4352 .find(|v| v.name == info.variant)
4353 .ok_or_else(|| TypeError::UnknownField {
4354 field: field_for_errors.clone(),
4355 typ: base_ty.to_string(),
4356 })?;
4357 return Ok((adt, variant));
4358 }
4359
4360 if let Some(adt_name) = type_head_name(base_ty) {
4361 let adt = adts.get(adt_name).ok_or_else(|| TypeError::UnknownField {
4362 field: field_for_errors.clone(),
4363 typ: base_ty.to_string(),
4364 })?;
4365 if adt.variants.len() == 1 {
4366 return Ok((adt, &adt.variants[0]));
4367 }
4368 return Err(TypeError::FieldNotKnown {
4369 field: field_for_errors.clone(),
4370 typ: base_ty.to_string(),
4371 });
4372 }
4373
4374 if matches!(base_ty.as_ref(), TypeKind::Var(_)) {
4375 let mut candidates = Vec::new();
4376 for adt in adts.values() {
4377 if adt.variants.len() != 1 {
4378 continue;
4379 }
4380 let variant = &adt.variants[0];
4381 let Some(fields) = record_fields(variant) else {
4382 continue;
4383 };
4384 if matches_fields(fields) {
4385 candidates.push((adt, variant));
4386 }
4387 }
4388 if candidates.len() == 1 {
4389 return Ok(candidates.remove(0));
4390 }
4391 if candidates.is_empty() {
4392 return Err(TypeError::UnknownField {
4393 field: field_for_errors.clone(),
4394 typ: base_ty.to_string(),
4395 });
4396 }
4397 return Err(TypeError::FieldNotKnown {
4398 field: field_for_errors.clone(),
4399 typ: base_ty.to_string(),
4400 });
4401 }
4402
4403 Err(TypeError::UnknownField {
4404 field: field_for_errors.clone(),
4405 typ: base_ty.to_string(),
4406 })
4407}
4408
4409fn resolve_record_update(
4410 unifier: &mut Unifier<'_>,
4411 supply: &mut TypeVarSupply,
4412 adts: &HashMap<Symbol, AdtDecl>,
4413 base_ty: &Type,
4414 known_variant: Option<KnownVariant>,
4415 update_fields: &[Symbol],
4416) -> Result<(Type, Vec<(Symbol, Type)>), TypeError> {
4417 if let TypeKind::Record(fields) = base_ty.as_ref() {
4418 return Ok((base_ty.clone(), fields.clone()));
4419 }
4420
4421 let field_for_errors = update_fields.first().cloned().unwrap_or_else(|| sym("_"));
4422
4423 let (adt, variant) =
4424 select_record_variant(adts, base_ty, known_variant, &field_for_errors, |fields| {
4425 update_fields
4426 .iter()
4427 .all(|field| fields.iter().any(|(name, _)| name == field))
4428 })?;
4429
4430 let (result_ty, fields) =
4431 instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
4432 TypeError::UnknownField {
4433 field: field_for_errors.clone(),
4434 typ: base_ty.to_string(),
4435 }
4436 })?;
4437
4438 for field in update_fields {
4439 if fields.iter().all(|(name, _)| name != field) {
4440 return Err(TypeError::UnknownField {
4441 field: field.clone(),
4442 typ: base_ty.to_string(),
4443 });
4444 }
4445 }
4446
4447 unifier.unify(base_ty, &result_ty)?;
4448 let result_ty = unifier.apply_type(&result_ty);
4449 let fields = fields
4450 .into_iter()
4451 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
4452 .collect();
4453 Ok((result_ty, fields))
4454}
4455
4456fn resolve_projection(
4457 unifier: &mut Unifier<'_>,
4458 supply: &mut TypeVarSupply,
4459 adts: &HashMap<Symbol, AdtDecl>,
4460 base_ty: &Type,
4461 known_variant: Option<KnownVariant>,
4462 field: &Symbol,
4463) -> Result<Type, TypeError> {
4464 if let Ok(index) = field.as_ref().parse::<usize>() {
4465 let elem_ty = match base_ty.as_ref() {
4466 TypeKind::Tuple(elems) => {
4467 elems
4468 .get(index)
4469 .cloned()
4470 .ok_or_else(|| TypeError::UnknownField {
4471 field: field.clone(),
4472 typ: base_ty.to_string(),
4473 })?
4474 }
4475 TypeKind::Var(_) => {
4476 let mut elems = Vec::with_capacity(index + 1);
4477 for _ in 0..=index {
4478 elems.push(Type::var(supply.fresh(Some(sym("t")))));
4479 }
4480 let tuple_ty = Type::tuple(elems.clone());
4481 unifier.unify(base_ty, &tuple_ty)?;
4482 elems[index].clone()
4483 }
4484 _ => {
4485 return Err(TypeError::UnknownField {
4486 field: field.clone(),
4487 typ: base_ty.to_string(),
4488 });
4489 }
4490 };
4491 return Ok(unifier.apply_type(&elem_ty));
4492 }
4493
4494 let (adt, variant) = select_record_variant(adts, base_ty, known_variant, field, |fields| {
4495 fields.iter().any(|(name, _)| name == field)
4496 })?;
4497
4498 let (result_ty, fields) =
4499 instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
4500 TypeError::UnknownField {
4501 field: field.clone(),
4502 typ: base_ty.to_string(),
4503 }
4504 })?;
4505 let field_ty = fields
4506 .iter()
4507 .find(|(name, _)| name == field)
4508 .map(|(_, ty)| ty.clone())
4509 .ok_or_else(|| TypeError::UnknownField {
4510 field: field.clone(),
4511 typ: base_ty.to_string(),
4512 })?;
4513 unifier.unify(base_ty, &result_ty)?;
4514 Ok(unifier.apply_type(&field_ty))
4515}
4516
4517fn decompose_fun(typ: &Type, arity: usize) -> Option<(Vec<Type>, Type)> {
4518 let mut args = Vec::with_capacity(arity);
4519 let mut cur = typ.clone();
4520 for _ in 0..arity {
4521 match cur.as_ref() {
4522 TypeKind::Fun(a, b) => {
4523 args.push(a.clone());
4524 cur = b.clone();
4525 }
4526 _ => return None,
4527 }
4528 }
4529 Some((args, cur))
4530}
4531
/// Successful result of `infer_pattern`: the predicates collected from the
/// pattern, and the variables it binds, each paired with its type.
type InferPatternResult = (Vec<Predicate>, Vec<(Symbol, Type)>);
4533
4534fn infer_pattern(
4535 unifier: &mut Unifier<'_>,
4536 supply: &mut TypeVarSupply,
4537 env: &TypeEnv,
4538 pat: &Pattern,
4539 scrutinee_ty: &Type,
4540) -> Result<InferPatternResult, TypeError> {
4541 let span = *pat.span();
4542 let res = (|| {
4543 unifier.charge_infer_node()?;
4544 match pat {
4545 Pattern::Wildcard(..) => Ok((vec![], vec![])),
4546 Pattern::Var(var) => Ok((
4547 vec![],
4548 vec![(var.name.clone(), unifier.apply_type(scrutinee_ty))],
4549 )),
4550 Pattern::Named(_, name, ps) => {
4551 let ctor_name = name.to_dotted_symbol();
4552 let schemes = env
4553 .lookup(&ctor_name)
4554 .ok_or_else(|| TypeError::UnknownVar(ctor_name.clone()))?;
4555 if schemes.len() != 1 {
4556 return Err(TypeError::AmbiguousOverload(ctor_name));
4557 }
4558 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
4559 let (preds, ctor_ty) = instantiate(&scheme, supply);
4560 let (arg_tys, res_ty) = decompose_fun(&ctor_ty, ps.len())
4561 .ok_or(TypeError::UnsupportedExpr("pattern constructor"))?;
4562 unifier.unify(&res_ty, scrutinee_ty)?;
4563 let mut all_preds = preds;
4564 let mut bindings = Vec::new();
4565 for (p, arg_ty) in ps.iter().zip(arg_tys.iter()) {
4566 let arg_ty = unifier.apply_type(arg_ty);
4567 let (p1, binds1) = infer_pattern(unifier, supply, env, p, &arg_ty)?;
4568 all_preds.extend(p1);
4569 bindings.extend(binds1);
4570 }
4571 let bindings = bindings
4572 .into_iter()
4573 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
4574 .collect();
4575 Ok((all_preds, bindings))
4576 }
4577 Pattern::List(_, ps) => {
4578 let elem_tv = Type::var(supply.fresh(Some("a".into())));
4579 let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
4580 unifier.unify(scrutinee_ty, &list_ty)?;
4581 let mut preds = Vec::new();
4582 let mut bindings = Vec::new();
4583 for p in ps {
4584 let elem_ty = unifier.apply_type(&elem_tv);
4585 let (p1, binds1) = infer_pattern(unifier, supply, env, p, &elem_ty)?;
4586 preds.extend(p1);
4587 bindings.extend(binds1);
4588 }
4589 let bindings = bindings
4590 .into_iter()
4591 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
4592 .collect();
4593 Ok((preds, bindings))
4594 }
4595 Pattern::Cons(_, head, tail) => {
4596 let elem_tv = Type::var(supply.fresh(Some("a".into())));
4597 let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
4598 unifier.unify(scrutinee_ty, &list_ty)?;
4599 let mut preds = Vec::new();
4600 let mut bindings = Vec::new();
4601
4602 let head_ty = unifier.apply_type(&elem_tv);
4603 let (p1, binds1) = infer_pattern(unifier, supply, env, head, &head_ty)?;
4604 preds.extend(p1);
4605 bindings.extend(binds1);
4606
4607 let tail_ty = Type::app(
4608 Type::builtin(BuiltinTypeId::List),
4609 unifier.apply_type(&elem_tv),
4610 );
4611 let (p2, binds2) = infer_pattern(unifier, supply, env, tail, &tail_ty)?;
4612 preds.extend(p2);
4613 bindings.extend(binds2);
4614
4615 let bindings = bindings
4616 .into_iter()
4617 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
4618 .collect();
4619 Ok((preds, bindings))
4620 }
4621 Pattern::Tuple(_, elems) => {
4622 let mut elem_tys: Vec<Type> = (0..elems.len())
4624 .map(|i| Type::var(supply.fresh(Some(format!("t{i}").into()))))
4625 .collect();
4626 let expected = Type::tuple(elem_tys.clone());
4627 unifier.unify(scrutinee_ty, &expected)?;
4628 elem_tys = elem_tys
4629 .into_iter()
4630 .map(|t| unifier.apply_type(&t))
4631 .collect();
4632
4633 let mut preds = Vec::new();
4634 let mut bindings = Vec::new();
4635 for (p, ty) in elems.iter().zip(elem_tys.iter()) {
4636 let (p_preds, p_binds) = infer_pattern(unifier, supply, env, p, ty)?;
4637 preds.extend(p_preds);
4638 bindings.extend(p_binds);
4639 }
4640 let bindings = bindings
4641 .into_iter()
4642 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
4643 .collect();
4644 Ok((preds, bindings))
4645 }
4646 Pattern::Dict(_, fields) => {
4647 if let TypeKind::Record(ty_fields) = scrutinee_ty.as_ref() {
4648 let mut preds = Vec::new();
4649 let mut bindings = Vec::new();
4650 for (key, pat) in fields {
4651 let ty = ty_fields
4652 .iter()
4653 .find(|(name, _)| name == key)
4654 .map(|(_, ty)| unifier.apply_type(ty))
4655 .ok_or_else(|| TypeError::UnknownField {
4656 field: key.clone(),
4657 typ: scrutinee_ty.to_string(),
4658 })?;
4659 let (p_preds, p_binds) = infer_pattern(unifier, supply, env, pat, &ty)?;
4660 preds.extend(p_preds);
4661 bindings.extend(p_binds);
4662 }
4663 let bindings = bindings
4664 .into_iter()
4665 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
4666 .collect();
4667 Ok((preds, bindings))
4668 } else {
4669 let elem_tv = Type::var(supply.fresh(Some("v".into())));
4670 let dict_ty = Type::app(Type::builtin(BuiltinTypeId::Dict), elem_tv.clone());
4671 unifier.unify(scrutinee_ty, &dict_ty)?;
4672 let elem_ty = unifier.apply_type(&elem_tv);
4673
4674 let mut preds = Vec::new();
4675 let mut bindings = Vec::new();
4676 for (_key, pat) in fields {
4677 let (p_preds, p_binds) =
4678 infer_pattern(unifier, supply, env, pat, &elem_ty)?;
4679 preds.extend(p_preds);
4680 bindings.extend(p_binds);
4681 }
4682 let bindings = bindings
4683 .into_iter()
4684 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
4685 .collect();
4686 Ok((preds, bindings))
4687 }
4688 }
4689 }
4690 })();
4691 res.map_err(|err| with_span(&span, err))
4692}
4693
4694fn type_head_name(typ: &Type) -> Option<&Symbol> {
4695 let mut cur = typ;
4696 while let TypeKind::App(head, _) = cur.as_ref() {
4697 cur = head;
4698 }
4699 match cur.as_ref() {
4700 TypeKind::Con(tc) => Some(&tc.name),
4701 _ => None,
4702 }
4703}
4704
4705#[derive(Clone, Debug, PartialEq, Eq)]
4706pub struct AdtConflict {
4707 pub name: Symbol,
4708 pub definitions: Vec<Type>,
4709}
4710
4711#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)]
4712#[error("conflicting ADT definitions: {conflicts:?}")]
4713pub struct CollectAdtsError {
4714 pub conflicts: Vec<AdtConflict>,
4715}
4716
4717pub fn collect_adts_in_types(types: Vec<Type>) -> Result<Vec<Type>, CollectAdtsError> {
4754 fn visit(
4755 typ: &Type,
4756 out: &mut Vec<Type>,
4757 seen: &mut HashSet<Type>,
4758 defs_by_name: &mut BTreeMap<Symbol, Vec<Type>>,
4759 ) {
4760 match typ.as_ref() {
4761 TypeKind::Var(_) => {}
4762 TypeKind::Con(tc) => {
4763 if tc.builtin_id.is_none() {
4765 let adt = Type::new(TypeKind::Con(tc.clone()));
4766 if seen.insert(adt.clone()) {
4767 out.push(adt.clone());
4768 }
4769 let defs = defs_by_name.entry(tc.name.clone()).or_default();
4770 if !defs.contains(&adt) {
4771 defs.push(adt);
4772 }
4773 }
4774 }
4775 TypeKind::App(fun, arg) => {
4776 visit(fun, out, seen, defs_by_name);
4777 visit(arg, out, seen, defs_by_name);
4778 }
4779 TypeKind::Fun(arg, ret) => {
4780 visit(arg, out, seen, defs_by_name);
4781 visit(ret, out, seen, defs_by_name);
4782 }
4783 TypeKind::Tuple(elems) => {
4784 for elem in elems {
4785 visit(elem, out, seen, defs_by_name);
4786 }
4787 }
4788 TypeKind::Record(fields) => {
4789 for (_name, field_ty) in fields {
4790 visit(field_ty, out, seen, defs_by_name);
4791 }
4792 }
4793 }
4794 }
4795
4796 let mut out = Vec::new();
4797 let mut seen = HashSet::new();
4798 let mut defs_by_name: BTreeMap<Symbol, Vec<Type>> = BTreeMap::new();
4799 for typ in &types {
4800 visit(typ, &mut out, &mut seen, &mut defs_by_name);
4801 }
4802
4803 let conflicts: Vec<AdtConflict> = defs_by_name
4804 .into_iter()
4805 .filter_map(|(name, definitions)| {
4806 (definitions.len() > 1).then_some(AdtConflict { name, definitions })
4807 })
4808 .collect();
4809 if !conflicts.is_empty() {
4810 return Err(CollectAdtsError { conflicts });
4811 }
4812
4813 Ok(out)
4814}
4815
4816fn adt_name_from_patterns(adts: &HashMap<Symbol, AdtDecl>, patterns: &[Pattern]) -> Option<Symbol> {
4817 let mut candidate: Option<Symbol> = None;
4818 for pat in patterns {
4819 let next = match pat {
4820 Pattern::Named(_, name, _) => {
4821 let name_sym = name.to_dotted_symbol();
4822 ctor_lookup(adts, &name_sym).map(|(adt, _)| adt.name.clone())
4823 }
4824 Pattern::List(..) | Pattern::Cons(..) => Some(sym("List")),
4825 _ => None,
4826 };
4827 if let Some(next) = next {
4828 match &candidate {
4829 None => candidate = Some(next),
4830 Some(prev) if *prev == next => {}
4831 Some(_) => return None,
4832 }
4833 }
4834 }
4835 candidate
4836}
4837
4838fn check_match_exhaustive(
4839 adts: &HashMap<Symbol, AdtDecl>,
4840 scrutinee_ty: &Type,
4841 patterns: &[Pattern],
4842) -> Result<(), TypeError> {
4843 if patterns
4844 .iter()
4845 .any(|p| matches!(p, Pattern::Wildcard(..) | Pattern::Var(_)))
4846 {
4847 return Ok(());
4848 }
4849 let adt_name = match type_head_name(scrutinee_ty).cloned() {
4850 Some(name) => name,
4851 None => match adt_name_from_patterns(adts, patterns) {
4852 Some(name) => name,
4853 None => return Ok(()),
4854 },
4855 };
4856 let adt = match adts.get(&adt_name) {
4857 Some(adt) => adt,
4858 None => return Ok(()),
4859 };
4860 let ctor_names: HashSet<Symbol> = adt.variants.iter().map(|v| v.name.clone()).collect();
4861 if ctor_names.is_empty() {
4862 return Ok(());
4863 }
4864 let mut covered = HashSet::new();
4865 for pat in patterns {
4866 match pat {
4867 Pattern::Named(_, name, _) => {
4868 let name_sym = name.to_dotted_symbol();
4869 if ctor_names.contains(&name_sym) {
4870 covered.insert(name_sym);
4871 }
4872 }
4873 Pattern::List(_, elems) if adt_name.as_ref() == "List" && elems.is_empty() => {
4874 covered.insert(sym("Empty"));
4875 }
4876 Pattern::Cons(..) if adt_name.as_ref() == "List" => {
4877 covered.insert(sym("Cons"));
4878 }
4879 _ => {}
4880 }
4881 }
4882 let mut missing: Vec<Symbol> = ctor_names.difference(&covered).cloned().collect();
4883 if missing.is_empty() {
4884 return Ok(());
4885 }
4886 missing.sort();
4887 Err(TypeError::NonExhaustiveMatch {
4888 typ: scrutinee_ty.to_string(),
4889 missing,
4890 })
4891}
4892
4893#[cfg(test)]
4894mod tests {
4895 use super::*;
4896 use rexlang_lexer::Token;
4897 use rexlang_parser::Parser;
4898 use rexlang_util::{GasCosts, GasMeter};
4899
4900 fn tvar(id: TypeVarId, name: &str) -> Type {
4901 Type::var(TypeVar::new(id, Some(sym(name))))
4902 }
4903
4904 fn dict_of(elem: Type) -> Type {
4905 Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
4906 }
4907
4908 #[test]
4909 fn unify_simple() {
4910 let t1 = Type::fun(tvar(0, "a"), Type::builtin(BuiltinTypeId::U32));
4911 let t2 = Type::fun(Type::builtin(BuiltinTypeId::U16), tvar(1, "b"));
4912 let subst = unify(&t1, &t2).unwrap();
4913 assert_eq!(subst.get(&0), Some(&Type::builtin(BuiltinTypeId::U16)));
4914 assert_eq!(subst.get(&1), Some(&Type::builtin(BuiltinTypeId::U32)));
4915 }
4916
4917 #[test]
4918 fn occurs_check_blocks_infinite_type() {
4919 let tv = TypeVar::new(0, Some(sym("a")));
4920 let t = Type::fun(Type::var(tv.clone()), Type::builtin(BuiltinTypeId::U8));
4921 let err = bind(&tv, &t).unwrap_err();
4922 assert!(matches!(err, TypeError::Occurs(_, _)));
4923 }
4924
4925 #[test]
4926 fn instantiate_and_generalize_round_trip() {
4927 let mut supply = TypeVarSupply::new();
4928 let a = Type::var(supply.fresh(Some(sym("a"))));
4929 let scheme = generalize(&TypeEnv::new(), vec![], Type::fun(a.clone(), a.clone()));
4930 let (preds, inst) = instantiate(&scheme, &mut supply);
4931 assert!(preds.is_empty());
4932 if let TypeKind::Fun(l, r) = inst.as_ref() {
4933 match (l.as_ref(), r.as_ref()) {
4934 (TypeKind::Var(_), TypeKind::Var(_)) => {}
4935 _ => panic!("expected polymorphic identity"),
4936 }
4937 } else {
4938 panic!("expected function type");
4939 }
4940 }
4941
4942 #[test]
4943 fn entail_superclasses() {
4944 let ts = TypeSystem::with_prelude().unwrap();
4945 let pred = Predicate::new("Semiring", Type::builtin(BuiltinTypeId::I32));
4946 let given = [Predicate::new(
4947 "AdditiveGroup",
4948 Type::builtin(BuiltinTypeId::I32),
4949 )];
4950 assert!(entails(&ts.classes, &given, &pred).unwrap());
4951 }
4952
4953 #[test]
4954 fn entail_instances() {
4955 let ts = TypeSystem::with_prelude().unwrap();
4956 let pred = Predicate::new("Field", Type::builtin(BuiltinTypeId::F32));
4957 assert!(entails(&ts.classes, &[], &pred).unwrap());
4958
4959 let pred_fail = Predicate::new("Field", Type::builtin(BuiltinTypeId::U32));
4960 assert!(!entails(&ts.classes, &[], &pred_fail).unwrap());
4961 }
4962
4963 #[test]
4964 fn prelude_injects_functions() {
4965 let ts = TypeSystem::with_prelude().unwrap();
4966 let minus = ts.env.lookup(&sym("-")).expect("minus in env");
4967 let div = ts.env.lookup(&sym("/")).expect("div in env");
4968 assert_eq!(minus.len(), 1);
4969 assert_eq!(div.len(), 1);
4970 let minus = &minus[0];
4971 let div = &div[0];
4972 assert_eq!(minus.preds.len(), 1);
4973 assert_eq!(minus.vars.len(), 1);
4974 assert_eq!(div.preds.len(), 1);
4975 assert_eq!(div.vars.len(), 1);
4976 }
4977
4978 #[test]
4979 fn adt_constructors_are_present() {
4980 let ts = TypeSystem::with_prelude().unwrap();
4981 assert!(ts.env.lookup(&sym("Empty")).is_some());
4982 assert!(ts.env.lookup(&sym("Cons")).is_some());
4983 assert!(ts.env.lookup(&sym("Ok")).is_some());
4984 assert!(ts.env.lookup(&sym("Err")).is_some());
4985 assert!(ts.env.lookup(&sym("Some")).is_some());
4986 assert!(ts.env.lookup(&sym("None")).is_some());
4987 }
4988
4989 fn parse_expr(code: &str) -> std::sync::Arc<rexlang_ast::expr::Expr> {
4990 let mut parser = Parser::new(Token::tokenize(code).unwrap());
4991 parser.parse_program(&mut GasMeter::default()).unwrap().expr
4992 }
4993
4994 fn parse_program(code: &str) -> rexlang_ast::expr::Program {
4995 let mut parser = Parser::new(Token::tokenize(code).unwrap());
4996 parser.parse_program(&mut GasMeter::default()).unwrap()
4997 }
4998
4999 #[test]
5000 fn infer_deep_list_does_not_overflow() {
5001 const N: usize = 40;
5003 let mut code = String::new();
5004 code.push_str("let xs = ");
5005 for _ in 0..N {
5006 code.push_str("Cons 0 (");
5007 }
5008 code.push_str("Empty");
5009 for _ in 0..N {
5010 code.push(')');
5011 }
5012 code.push_str(" in xs");
5013
5014 let parse_handle = std::thread::Builder::new()
5015 .name("infer_deep_list_parse".into())
5016 .stack_size(128 * 1024 * 1024)
5017 .spawn(move || {
5018 let tokens = Token::tokenize(&code).unwrap();
5019 let mut parser = Parser::new(tokens);
5020 parser.parse_program(&mut GasMeter::default())
5021 })
5022 .unwrap();
5023 let program = parse_handle.join().unwrap().unwrap();
5024 let expr = program.expr;
5025 let mut ts = TypeSystem::with_prelude().unwrap();
5026 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5027 assert_eq!(
5028 ty,
5029 Type::app(
5030 Type::builtin(BuiltinTypeId::List),
5031 Type::builtin(BuiltinTypeId::I32)
5032 )
5033 );
5034 }
5035
5036 #[test]
5037 fn collect_adts_in_types_finds_nested_unique_adts() {
5038 let foo = Type::user_con("Foo", 1);
5039 let bar = Type::user_con("Bar", 0);
5040 let ty = Type::fun(
5041 Type::app(
5042 Type::builtin(BuiltinTypeId::List),
5043 Type::app(foo.clone(), tvar(0, "a")),
5044 ),
5045 Type::tuple(vec![
5046 Type::app(foo.clone(), Type::builtin(BuiltinTypeId::I32)),
5047 bar.clone(),
5048 ]),
5049 );
5050
5051 let adts = collect_adts_in_types(vec![ty]).unwrap();
5052 assert_eq!(adts, vec![foo, bar]);
5053 }
5054
5055 #[test]
5056 fn collect_adts_in_types_rejects_conflicting_names() {
5057 let arity1 = Type::user_con("Thing", 1);
5058 let arity2 = Type::user_con("Thing", 2);
5059
5060 let err = collect_adts_in_types(vec![arity1.clone(), arity2.clone()]).unwrap_err();
5061 assert_eq!(err.conflicts.len(), 1);
5062 let conflict = &err.conflicts[0];
5063 assert_eq!(conflict.name, sym("Thing"));
5064 assert_eq!(conflict.definitions, vec![arity1, arity2]);
5065 }
5066
5067 #[test]
5068 fn infer_depth_limit_is_enforced() {
5069 const N: usize = 40;
5070 let mut code = String::new();
5071 code.push_str("let xs = ");
5072 for _ in 0..N {
5073 code.push_str("Cons 0 (");
5074 }
5075 code.push_str("Empty");
5076 for _ in 0..N {
5077 code.push(')');
5078 }
5079 code.push_str(" in xs");
5080
5081 let program = parse_program(&code);
5082 let mut ts = TypeSystem::with_prelude().unwrap();
5083 ts.set_limits(TypeSystemLimits {
5084 max_infer_depth: Some(8),
5085 });
5086
5087 let err = ts.infer(program.expr.as_ref()).unwrap_err();
5088 assert!(
5089 err.to_string().contains("maximum inference depth exceeded"),
5090 "expected a max-depth inference error, got: {err:?}"
5091 );
5092 }
5093
5094 #[test]
5095 fn declare_fn_injects_scheme_for_use_sites() {
5096 let program = parse_program(
5097 r#"
5098 declare fn id x: a -> a
5099 id 1
5100 "#,
5101 );
5102 let mut ts = TypeSystem::with_prelude().unwrap();
5103 ts.inject_decls(&program.decls).unwrap();
5104 let (preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5105 assert!(
5106 preds.is_empty()
5107 || preds.iter().all(|p| p.class.as_ref() == "Integral"
5108 && p.typ == Type::builtin(BuiltinTypeId::I32))
5109 );
5110 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5111 }
5112
5113 #[test]
5114 fn declare_fn_is_noop_when_matching_existing_scheme() {
5115 let mut ts = TypeSystem::with_prelude().unwrap();
5116 ts.add_value(
5117 "foo",
5118 Scheme::new(
5119 vec![],
5120 vec![],
5121 Type::fun(
5122 Type::builtin(BuiltinTypeId::I32),
5123 Type::builtin(BuiltinTypeId::I32),
5124 ),
5125 ),
5126 );
5127
5128 let program = parse_program(
5129 r#"
5130 declare fn foo x: i32 -> i32
5131 0
5132 "#,
5133 );
5134 let rexlang_ast::expr::Decl::DeclareFn(fd) = &program.decls[0] else {
5135 panic!("expected declare fn decl");
5136 };
5137 ts.inject_declare_fn_decl(fd).unwrap();
5138 }
5139
5140 #[test]
5141 fn unit_type_parses_and_infers() {
5142 let program = parse_program(
5143 r#"
5144 fn unit_id x: () -> () = x
5145 unit_id ()
5146 "#,
5147 );
5148 let mut ts = TypeSystem::with_prelude().unwrap();
5149 ts.inject_decls(&program.decls).unwrap();
5150 let (preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5151 assert!(preds.is_empty());
5152 assert_eq!(ty, Type::tuple(vec![]));
5153 }
5154
5155 fn strip_span(mut err: TypeError) -> TypeError {
5156 while let TypeError::Spanned { error, .. } = err {
5157 err = *error;
5158 }
5159 err
5160 }
5161
5162 #[test]
5163 fn type_errors_include_span() {
5164 let expr = parse_expr("missing");
5165 let mut ts = TypeSystem::with_prelude().unwrap();
5166 let err = ts.infer(expr.as_ref()).unwrap_err();
5167 match err {
5168 TypeError::Spanned { span, error } => {
5169 assert_ne!(span, Span::default());
5170 assert!(matches!(
5171 *error,
5172 TypeError::UnknownVar(name) if name.as_ref() == "missing"
5173 ));
5174 }
5175 other => panic!("expected spanned error, got {other:?}"),
5176 }
5177 }
5178
5179 #[test]
5180 fn infer_with_gas_rejects_out_of_budget() {
5181 let expr = parse_expr("1");
5182 let mut ts = TypeSystem::with_prelude().unwrap();
5183 let mut gas = GasMeter::new(
5184 Some(0),
5185 GasCosts {
5186 infer_node: 1,
5187 unify_step: 0,
5188 ..GasCosts::sensible_defaults()
5189 },
5190 );
5191 let err = ts.infer_with_gas(expr.as_ref(), &mut gas).unwrap_err();
5192 assert!(matches!(strip_span(err), TypeError::OutOfGas(..)));
5193 }
5194
5195 #[test]
5196 fn reject_user_redefinition_of_primitive_type_name() {
5197 let program = parse_program("type i32 = I32Wrap i32");
5198 let mut ts = TypeSystem::with_prelude().unwrap();
5199 let rexlang_ast::expr::Decl::Type(decl) = &program.decls[0] else {
5200 panic!("expected type decl");
5201 };
5202 let err = ts.inject_type_decl(decl).unwrap_err();
5203 assert!(matches!(
5204 err,
5205 TypeError::ReservedTypeName(name) if name.as_ref() == "i32"
5206 ));
5207 }
5208
5209 #[test]
5210 fn reject_user_redefinition_of_prelude_adt_name() {
5211 let program = parse_program("type Result e a = Nope e a");
5212 let mut ts = TypeSystem::with_prelude().unwrap();
5213 let rexlang_ast::expr::Decl::Type(decl) = &program.decls[0] else {
5214 panic!("expected type decl");
5215 };
5216 let err = ts.inject_type_decl(decl).unwrap_err();
5217 assert!(matches!(
5218 err,
5219 TypeError::ReservedTypeName(name) if name.as_ref() == "Result"
5220 ));
5221 }
5222
5223 #[test]
5224 fn infer_polymorphic_id_tuple() {
5225 let expr = parse_expr(
5226 r#"
5227 let
5228 id = \x -> x
5229 in
5230 id (id 420, id 6.9, id "str")
5231 "#,
5232 );
5233 let mut ts = TypeSystem::with_prelude().unwrap();
5234 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5235 let expected = Type::tuple(vec![
5236 Type::builtin(BuiltinTypeId::I32),
5237 Type::builtin(BuiltinTypeId::F32),
5238 Type::builtin(BuiltinTypeId::String),
5239 ]);
5240 assert_eq!(ty, expected);
5241 }
5242
5243 #[test]
5244 fn infer_type_annotation_ok() {
5245 let expr = parse_expr("let x: i32 = 42 in x");
5246 let mut ts = TypeSystem::with_prelude().unwrap();
5247 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5248 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5249 }
5250
5251 #[test]
5252 fn infer_type_annotation_lambda_param() {
5253 let expr = parse_expr("\\ (a : f32) -> a");
5254 let mut ts = TypeSystem::with_prelude().unwrap();
5255 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5256 assert_eq!(
5257 ty,
5258 Type::fun(
5259 Type::builtin(BuiltinTypeId::F32),
5260 Type::builtin(BuiltinTypeId::F32)
5261 )
5262 );
5263 }
5264
5265 #[test]
5266 fn infer_type_annotation_is_alias() {
5267 let expr = parse_expr("\"hi\" is str");
5268 let mut ts = TypeSystem::with_prelude().unwrap();
5269 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5270 assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
5271 }
5272
5273 #[test]
5274 fn infer_type_annotation_mismatch_error() {
5275 let expr = parse_expr("let x: i32 = 3.14 in x");
5276 let mut ts = TypeSystem::with_prelude().unwrap();
5277 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5278 assert!(matches!(err, TypeError::Unification(_, _)));
5279 }
5280
5281 #[test]
5282 fn infer_project_single_variant_let() {
5283 let program = parse_program(
5284 r#"
5285 type MyADT = MyVariant1 { field1: i32, field2: f32 }
5286 let
5287 x = MyVariant1 { field1 = 1, field2 = 2.0 }
5288 in
5289 (x.field1, x.field2)
5290 "#,
5291 );
5292 let mut ts = TypeSystem::with_prelude().unwrap();
5293 for decl in &program.decls {
5294 if let rexlang_ast::expr::Decl::Type(decl) = decl {
5295 ts.inject_type_decl(decl).unwrap();
5296 }
5297 }
5298 let (_preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5299 let expected = Type::tuple(vec![
5300 Type::builtin(BuiltinTypeId::I32),
5301 Type::builtin(BuiltinTypeId::F32),
5302 ]);
5303 assert_eq!(ty, expected);
5304 }
5305
5306 #[test]
5307 fn infer_project_known_variant_let() {
5308 let program = parse_program(
5309 r#"
5310 type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
5311 let
5312 x = MyVariant1 { field1 = 1, field2 = 2.0 }
5313 in
5314 x.field1
5315 "#,
5316 );
5317 let mut ts = TypeSystem::with_prelude().unwrap();
5318 for decl in &program.decls {
5319 if let rexlang_ast::expr::Decl::Type(decl) = decl {
5320 ts.inject_type_decl(decl).unwrap();
5321 }
5322 }
5323 let (_preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5324 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5325 }
5326
5327 #[test]
5328 fn infer_project_unknown_variant_error() {
5329 let program = parse_program(
5330 r#"
5331 type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
5332 let
5333 x = MyVariant2 1 2.0
5334 in
5335 x.field1
5336 "#,
5337 );
5338 let mut ts = TypeSystem::with_prelude().unwrap();
5339 for decl in &program.decls {
5340 if let rexlang_ast::expr::Decl::Type(decl) = decl {
5341 ts.inject_type_decl(decl).unwrap();
5342 }
5343 }
5344 let err = strip_span(ts.infer(program.expr.as_ref()).unwrap_err());
5345 assert!(matches!(err, TypeError::FieldNotKnown { .. }));
5346 }
5347
5348 #[test]
5349 fn infer_project_lambda_param_single_variant() {
5350 let program = parse_program(
5351 r#"
5352 type Boxed = Boxed { value: i32 }
5353 let
5354 f = \x -> x.value
5355 in
5356 f (Boxed { value = 1 })
5357 "#,
5358 );
5359 let mut ts = TypeSystem::with_prelude().unwrap();
5360 for decl in &program.decls {
5361 if let rexlang_ast::expr::Decl::Type(decl) = decl {
5362 ts.inject_type_decl(decl).unwrap();
5363 }
5364 }
5365 let (_preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5366 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5367 }
5368
5369 #[test]
5370 fn infer_project_in_match_arm() {
5371 let program = parse_program(
5372 r#"
5373 type MyADT = MyVariant1 { field1: i32 } | MyVariant2 i32
5374 let
5375 x = MyVariant1 { field1 = 1 }
5376 in
5377 match x
5378 when MyVariant1 { field1 } -> x.field1
5379 when MyVariant2 _ -> 0
5380 "#,
5381 );
5382 let mut ts = TypeSystem::with_prelude().unwrap();
5383 for decl in &program.decls {
5384 if let rexlang_ast::expr::Decl::Type(decl) = decl {
5385 ts.inject_type_decl(decl).unwrap();
5386 }
5387 }
5388 let (_preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5389 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5390 }
5391
5392 #[test]
5393 fn infer_nested_let_lambda_match_option() {
5394 let expr = parse_expr(
5395 r#"
5396 let
5397 choose = \flag a b -> if flag then a else b,
5398 build = \flag ->
5399 let
5400 pick = choose flag,
5401 val = pick 1 2
5402 in
5403 Some val
5404 in
5405 match (build true)
5406 when Some x -> x
5407 when None -> 0
5408 "#,
5409 );
5410 let mut ts = TypeSystem::with_prelude().unwrap();
5411 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5412 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5413 }
5414
5415 #[test]
5416 fn infer_polymorphic_apply_in_tuple() {
5417 let expr = parse_expr(
5418 r#"
5419 let
5420 apply = \f x -> f x,
5421 id = \x -> x,
5422 wrap = \x -> (x, x)
5423 in
5424 (apply id 1, apply id "hi", apply wrap true)
5425 "#,
5426 );
5427 let mut ts = TypeSystem::with_prelude().unwrap();
5428 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5429 let expected = Type::tuple(vec![
5430 Type::builtin(BuiltinTypeId::I32),
5431 Type::builtin(BuiltinTypeId::String),
5432 Type::tuple(vec![
5433 Type::builtin(BuiltinTypeId::Bool),
5434 Type::builtin(BuiltinTypeId::Bool),
5435 ]),
5436 ]);
5437 assert_eq!(ty, expected);
5438 }
5439
5440 #[test]
5441 fn infer_nested_result_option_match() {
5442 let expr = parse_expr(
5443 r#"
5444 let
5445 unwrap = \x ->
5446 match x
5447 when Ok (Some v) -> v
5448 when Ok None -> 0
5449 when Err _ -> 0
5450 in
5451 unwrap (Ok (Some 5))
5452 "#,
5453 );
5454 let mut ts = TypeSystem::with_prelude().unwrap();
5455 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5456 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5457 }
5458
5459 #[test]
5460 fn infer_head_or_list_match() {
5461 let expr = parse_expr(
5462 r#"
5463 let
5464 head_or = \fallback xs ->
5465 match xs
5466 when [] -> fallback
5467 when x::xs -> x
5468 in
5469 (head_or 0 [1, 2, 3], head_or 0 [])
5470 "#,
5471 );
5472 let mut ts = TypeSystem::with_prelude().unwrap();
5473 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5474 let expected = Type::tuple(vec![
5475 Type::builtin(BuiltinTypeId::I32),
5476 Type::builtin(BuiltinTypeId::I32),
5477 ]);
5478 assert_eq!(ty, expected);
5479 }
5480
5481 #[test]
5482 fn infer_head_or_list_match_cons_constructor_form() {
5483 let expr = parse_expr(
5484 r#"
5485 let
5486 head_or = \fallback xs ->
5487 match xs
5488 when [] -> fallback
5489 when Cons x xs1 -> x
5490 in
5491 (head_or 0 (Cons 1 (Cons 2 Empty)), head_or 0 Empty)
5492 "#,
5493 );
5494 let mut ts = TypeSystem::with_prelude().unwrap();
5495 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5496 let expected = Type::tuple(vec![
5497 Type::builtin(BuiltinTypeId::I32),
5498 Type::builtin(BuiltinTypeId::I32),
5499 ]);
5500 assert_eq!(ty, expected);
5501 }
5502
5503 #[test]
5504 fn infer_record_pattern_in_lambda() {
5505 let program = parse_program(
5506 r#"
5507 type Pair = Pair { left: i32, right: i32 }
5508 let
5509 sum = \p ->
5510 match p
5511 when Pair { left, right } -> left + right
5512 in
5513 sum (Pair { left = 1, right = 2 })
5514 "#,
5515 );
5516 let mut ts = TypeSystem::with_prelude().unwrap();
5517 for decl in &program.decls {
5518 if let rexlang_ast::expr::Decl::Type(decl) = decl {
5519 ts.inject_type_decl(decl).unwrap();
5520 }
5521 }
5522 let (_preds, ty) = ts.infer(program.expr.as_ref()).unwrap();
5523 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5524 }
5525
5526 #[test]
5527 fn infer_fn_decl_simple() {
5528 let program = parse_program(
5529 r#"
5530 fn add (x: i32, y: i32) -> i32 = x + y
5531 add 1 2
5532 "#,
5533 );
5534 let mut ts = TypeSystem::with_prelude().unwrap();
5535 let expr = program.expr_with_fns();
5536 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5537 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5538 }
5539
5540 #[test]
5541 fn infer_fn_decl_signature_form() {
5542 let program = parse_program(
5543 r#"
5544 fn add : i32 -> i32 -> i32 = \x y -> x + y
5545 add 1 2
5546 "#,
5547 );
5548 let mut ts = TypeSystem::with_prelude().unwrap();
5549 let expr = program.expr_with_fns();
5550 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5551 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5552 }
5553
5554 #[test]
5555 fn infer_fn_decl_polymorphic_where_constraints() {
5556 let program = parse_program(
5557 r#"
5558 fn my_add (x: a, y: a) -> a where AdditiveMonoid a = x + y
5559 (my_add 1 2, my_add 1.0 2.0)
5560 "#,
5561 );
5562 let mut ts = TypeSystem::with_prelude().unwrap();
5563 let expr = program.expr_with_fns();
5564 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
5565 assert_eq!(
5566 ty,
5567 Type::tuple(vec![
5568 Type::builtin(BuiltinTypeId::I32),
5569 Type::builtin(BuiltinTypeId::F32)
5570 ])
5571 );
5572 }
5573
5574 #[test]
5575 fn infer_additive_monoid_constraint() {
5576 let expr = parse_expr("\\x y -> x + y");
5577 let mut ts = TypeSystem::with_prelude().unwrap();
5578 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5579 assert_eq!(preds.len(), 1);
5580 assert_eq!(preds[0].class.as_ref(), "AdditiveMonoid");
5581
5582 if let TypeKind::Fun(a, rest) = ty.as_ref()
5583 && let TypeKind::Fun(b, c) = rest.as_ref()
5584 {
5585 assert_eq!(a.as_ref(), b.as_ref());
5586 assert_eq!(b.as_ref(), c.as_ref());
5587 assert_eq!(preds[0].typ, a.clone());
5588 return;
5589 }
5590 panic!("expected a -> a -> a");
5591 }
5592
5593 #[test]
5594 fn infer_multiplicative_monoid_constraint() {
5595 let expr = parse_expr("\\x y -> x * y");
5596 let mut ts = TypeSystem::with_prelude().unwrap();
5597 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5598 assert_eq!(preds.len(), 1);
5599 assert_eq!(preds[0].class.as_ref(), "MultiplicativeMonoid");
5600
5601 if let TypeKind::Fun(a, rest) = ty.as_ref()
5602 && let TypeKind::Fun(b, c) = rest.as_ref()
5603 {
5604 assert_eq!(a.as_ref(), b.as_ref());
5605 assert_eq!(b.as_ref(), c.as_ref());
5606 assert_eq!(preds[0].typ, a.clone());
5607 return;
5608 }
5609 panic!("expected a -> a -> a");
5610 }
5611
5612 #[test]
5613 fn infer_additive_group_constraint() {
5614 let expr = parse_expr("\\x y -> x - y");
5615 let mut ts = TypeSystem::with_prelude().unwrap();
5616 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5617 assert_eq!(preds.len(), 1);
5618 assert_eq!(preds[0].class.as_ref(), "AdditiveGroup");
5619
5620 if let TypeKind::Fun(a, rest) = ty.as_ref()
5621 && let TypeKind::Fun(b, c) = rest.as_ref()
5622 {
5623 assert_eq!(a.as_ref(), b.as_ref());
5624 assert_eq!(b.as_ref(), c.as_ref());
5625 assert_eq!(preds[0].typ, a.clone());
5626 return;
5627 }
5628 panic!("expected a -> a -> a");
5629 }
5630
5631 #[test]
5632 fn infer_integral_constraint() {
5633 let expr = parse_expr("\\x y -> x % y");
5634 let mut ts = TypeSystem::with_prelude().unwrap();
5635 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5636 assert_eq!(preds.len(), 1);
5637 assert_eq!(preds[0].class.as_ref(), "Integral");
5638
5639 if let TypeKind::Fun(a, rest) = ty.as_ref()
5640 && let TypeKind::Fun(b, c) = rest.as_ref()
5641 {
5642 assert_eq!(a.as_ref(), b.as_ref());
5643 assert_eq!(b.as_ref(), c.as_ref());
5644 assert_eq!(preds[0].typ, a.clone());
5645 return;
5646 }
5647 panic!("expected a -> a -> a");
5648 }
5649
5650 #[test]
5651 fn infer_literal_addition_defaults() {
5652 let expr = parse_expr("1 + 2");
5653 let mut ts = TypeSystem::with_prelude().unwrap();
5654 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5655 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5656 assert_eq!(preds.len(), 2);
5657 assert!(preds.iter().any(|p| p.class.as_ref() == "AdditiveMonoid"));
5658 assert!(preds.iter().any(|p| p.class.as_ref() == "Integral"));
5659 assert!(
5660 preds
5661 .iter()
5662 .all(|p| p.typ == Type::builtin(BuiltinTypeId::I32))
5663 );
5664 }
5665
5666 #[test]
5667 fn infer_mod_defaults() {
5668 let expr = parse_expr("1 % 2");
5669 let mut ts = TypeSystem::with_prelude().unwrap();
5670 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5671 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5672 assert_eq!(preds.len(), 1);
5673 assert_eq!(preds[0].class.as_ref(), "Integral");
5674 assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::I32));
5675 }
5676
5677 #[test]
5678 fn infer_get_list_type() {
5679 let expr = parse_expr("get 1 [1, 2, 3]");
5680 let mut ts = TypeSystem::with_prelude().unwrap();
5681 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5682 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5683 assert!(preds.iter().any(|p| p.class.as_ref() == "Indexable"));
5684 assert!(preds.iter().all(|p| {
5685 p.class.as_ref() == "Indexable"
5686 || (p.class.as_ref() == "Integral" && p.typ == Type::builtin(BuiltinTypeId::I32))
5687 }));
5688 for pred in preds.iter().filter(|p| p.class.as_ref() == "Indexable") {
5689 assert!(entails(&ts.classes, &[], pred).unwrap());
5690 }
5691 }
5692
5693 #[test]
5694 fn infer_get_tuple_type() {
5695 let expr = parse_expr("(1, 'Hello', true).0");
5696 let mut ts = TypeSystem::with_prelude().unwrap();
5697 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5698 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
5699 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
5700
5701 let expr = parse_expr("(1, 'Hello', true).1");
5702 let mut ts = TypeSystem::with_prelude().unwrap();
5703 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5704 assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
5705 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
5706
5707 let expr = parse_expr("(1, 'Hello', true).2");
5708 let mut ts = TypeSystem::with_prelude().unwrap();
5709 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5710 assert_eq!(ty, Type::builtin(BuiltinTypeId::Bool));
5711 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
5712 }
5713
5714 #[test]
5715 fn infer_division_defaults() {
5716 let expr = parse_expr("1.0 / 2.0");
5717 let mut ts = TypeSystem::with_prelude().unwrap();
5718 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
5719 assert_eq!(ty, Type::builtin(BuiltinTypeId::F32));
5720 assert_eq!(preds.len(), 1);
5721 assert_eq!(preds[0].class.as_ref(), "Field");
5722 assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::F32));
5723 assert!(entails(&ts.classes, &[], &preds[0]).unwrap());
5724 }
5725
5726 #[test]
5727 fn infer_unbound_variable_error() {
5728 let expr = parse_expr("missing");
5729 let mut ts = TypeSystem::with_prelude().unwrap();
5730 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5731 assert!(matches!(
5732 err,
5733 TypeError::UnknownVar(name) if name.as_ref() == "missing"
5734 ));
5735 }
5736
5737 #[test]
5738 fn infer_if_branch_type_mismatch_error() {
5739 let expr = parse_expr(r#"if true then 1 else "no""#);
5740 let mut ts = TypeSystem::with_prelude().unwrap();
5741 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5742 match err {
5743 TypeError::Unification(a, b) => {
5744 let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
5745 assert!(ok, "expected i32 vs string, got {a} vs {b}");
5746 }
5747 other => panic!("expected unification error, got {other:?}"),
5748 }
5749 }
5750
5751 #[test]
5752 fn infer_unknown_pattern_constructor_error() {
5753 let expr = parse_expr("match 1 when Nope -> 1");
5754 let mut ts = TypeSystem::with_prelude().unwrap();
5755 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5756 assert!(matches!(
5757 err,
5758 TypeError::UnknownVar(name) if name.as_ref() == "Nope"
5759 ));
5760 }
5761
5762 #[test]
5763 fn infer_ambiguous_overload_error() {
5764 let mut ts = TypeSystem::new();
5765 let a = TypeVar::new(0, Some(sym("a")));
5766 let b = TypeVar::new(1, Some(sym("b")));
5767 let scheme_a = Scheme::new(vec![a.clone()], vec![], Type::var(a));
5768 let scheme_b = Scheme::new(vec![b.clone()], vec![], Type::var(b));
5769 ts.add_overload(sym("dup"), scheme_a);
5770 ts.add_overload(sym("dup"), scheme_b);
5771 let expr = parse_expr("dup");
5772 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5773 assert!(matches!(
5774 err,
5775 TypeError::AmbiguousOverload(name) if name.as_ref() == "dup"
5776 ));
5777 }
5778
5779 #[test]
5780 fn infer_if_cond_not_bool_error() {
5781 let expr = parse_expr("if 1 then 2 else 3");
5782 let mut ts = TypeSystem::with_prelude().unwrap();
5783 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5784 match err {
5785 TypeError::Unification(a, b) => {
5786 let ok = (a == "bool" && b == "i32") || (a == "i32" && b == "bool");
5787 assert!(ok, "expected bool vs i32, got {a} vs {b}");
5788 }
5789 other => panic!("expected unification error, got {other:?}"),
5790 }
5791 }
5792
5793 #[test]
5794 fn infer_apply_non_function_error() {
5795 let expr = parse_expr("1 2");
5796 let mut ts = TypeSystem::with_prelude().unwrap();
5797 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5798 assert!(matches!(err, TypeError::Unification(_, _)));
5799 }
5800
5801 #[test]
5802 fn infer_list_element_mismatch_error() {
5803 let expr = parse_expr("[1, true]");
5804 let mut ts = TypeSystem::with_prelude().unwrap();
5805 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5806 match err {
5807 TypeError::Unification(a, b) => {
5808 let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
5809 assert!(ok, "expected i32 vs bool, got {a} vs {b}");
5810 }
5811 other => panic!("expected unification error, got {other:?}"),
5812 }
5813 }
5814
5815 #[test]
5816 fn infer_dict_value_mismatch_error() {
5817 let expr = parse_expr("{a = 1, b = true}");
5818 let mut ts = TypeSystem::with_prelude().unwrap();
5819 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5820 match err {
5821 TypeError::Unification(a, b) => {
5822 let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
5823 assert!(ok, "expected i32 vs bool, got {a} vs {b}");
5824 }
5825 other => panic!("expected unification error, got {other:?}"),
5826 }
5827 }
5828
5829 #[test]
5830 fn infer_match_list_on_non_list_error() {
5831 let expr = parse_expr("match 1 when [x] -> x");
5832 let mut ts = TypeSystem::with_prelude().unwrap();
5833 assert!(ts.infer(expr.as_ref()).is_err());
5834 }
5835
5836 #[test]
5837 fn infer_pattern_constructor_arity_error() {
5838 let expr = parse_expr("match (Ok 1) when Ok x y -> x");
5839 let mut ts = TypeSystem::with_prelude().unwrap();
5840 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5841 assert!(matches!(
5842 err,
5843 TypeError::UnsupportedExpr("pattern constructor")
5844 ));
5845 }
5846
5847 #[test]
5848 fn infer_match_arm_type_mismatch_error() {
5849 let expr = parse_expr(r#"match 1 when _ -> 1 when _ -> "no""#);
5850 let mut ts = TypeSystem::with_prelude().unwrap();
5851 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5852 match err {
5853 TypeError::Unification(a, b) => {
5854 let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
5855 assert!(ok, "expected i32 vs string, got {a} vs {b}");
5856 }
5857 other => panic!("expected unification error, got {other:?}"),
5858 }
5859 }
5860
5861 #[test]
5862 fn infer_match_option_on_non_option_error() {
5863 let expr = parse_expr("match 1 when Some x -> x");
5864 let mut ts = TypeSystem::with_prelude().unwrap();
5865 assert!(ts.infer(expr.as_ref()).is_err());
5866 }
5867
5868 #[test]
5869 fn infer_dict_pattern_on_non_dict_error() {
5870 let expr = parse_expr("match 1 when {a} -> a");
5871 let mut ts = TypeSystem::with_prelude().unwrap();
5872 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5873 assert!(matches!(err, TypeError::Unification(_, _)));
5874 }
5875
5876 #[test]
5877 fn infer_cons_pattern_on_non_list_error() {
5878 let expr = parse_expr("match 1 when x::xs -> x");
5879 let mut ts = TypeSystem::with_prelude().unwrap();
5880 assert!(ts.infer(expr.as_ref()).is_err());
5881 }
5882
5883 #[test]
5884 fn infer_apply_wrong_arg_type_error() {
5885 let expr = parse_expr("(\\x -> x + 1) true");
5886 let mut ts = TypeSystem::with_prelude().unwrap();
5887 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5888 assert!(matches!(err, TypeError::Unification(_, _)));
5889 }
5890
5891 #[test]
5892 fn infer_self_application_occurs_error() {
5893 let expr = parse_expr("\\x -> x x");
5894 let mut ts = TypeSystem::with_prelude().unwrap();
5895 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5896 assert!(matches!(err, TypeError::Occurs(_, _)));
5897 }
5898
5899 #[test]
5900 fn infer_apply_constructor_too_many_args_error() {
5901 let expr = parse_expr("Some 1 2");
5902 let mut ts = TypeSystem::with_prelude().unwrap();
5903 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5904 assert!(matches!(err, TypeError::Unification(_, _)));
5905 }
5906
5907 #[test]
5908 fn infer_operator_type_mismatch_error() {
5909 let expr = parse_expr("1 + true");
5910 let mut ts = TypeSystem::with_prelude().unwrap();
5911 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5912 assert!(matches!(err, TypeError::Unification(_, _)));
5913 }
5914
5915 #[test]
5916 fn infer_non_exhaustive_match_is_error() {
5917 let expr = parse_expr("match (Ok 1) when Ok x -> x");
5918 let mut ts = TypeSystem::with_prelude().unwrap();
5919 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5920 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
5921 }
5922
5923 #[test]
5924 fn infer_non_exhaustive_match_on_bound_var_error() {
5925 let expr = parse_expr("let x = Ok 1 in match x when Ok y -> y");
5926 let mut ts = TypeSystem::with_prelude().unwrap();
5927 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5928 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
5929 }
5930
5931 #[test]
5932 fn infer_non_exhaustive_match_in_lambda_error() {
5933 let expr = parse_expr("\\x -> match x when Ok y -> y");
5934 let mut ts = TypeSystem::with_prelude().unwrap();
5935 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5936 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
5937 }
5938
5939 #[test]
5940 fn infer_non_exhaustive_option_match_error() {
5941 let expr = parse_expr("match (Some 1) when Some x -> x");
5942 let mut ts = TypeSystem::with_prelude().unwrap();
5943 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5944 match err {
5945 TypeError::NonExhaustiveMatch { missing, .. } => {
5946 assert_eq!(missing, vec![sym("None")]);
5947 }
5948 other => panic!("expected non-exhaustive match, got {other:?}"),
5949 }
5950 }
5951
5952 #[test]
5953 fn infer_non_exhaustive_result_match_error() {
5954 let expr = parse_expr("match (Err 1) when Ok x -> x");
5955 let mut ts = TypeSystem::with_prelude().unwrap();
5956 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5957 match err {
5958 TypeError::NonExhaustiveMatch { missing, .. } => {
5959 assert_eq!(missing, vec![sym("Err")]);
5960 }
5961 other => panic!("expected non-exhaustive match, got {other:?}"),
5962 }
5963 }
5964
5965 #[test]
5966 fn infer_non_exhaustive_list_missing_empty_error() {
5967 let expr = parse_expr("match [1, 2] when x::xs -> x");
5968 let mut ts = TypeSystem::with_prelude().unwrap();
5969 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5970 match err {
5971 TypeError::NonExhaustiveMatch { missing, .. } => {
5972 assert_eq!(missing, vec![sym("Empty")]);
5973 }
5974 other => panic!("expected non-exhaustive match, got {other:?}"),
5975 }
5976 }
5977
5978 #[test]
5979 fn infer_non_exhaustive_list_match_on_bound_var_error() {
5980 let expr = parse_expr("let xs = [1, 2] in match xs when x::xs -> x");
5981 let mut ts = TypeSystem::with_prelude().unwrap();
5982 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5983 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
5984 }
5985
5986 #[test]
5987 fn infer_non_exhaustive_list_missing_cons_error() {
5988 let expr = parse_expr("match [1] when [] -> 0");
5989 let mut ts = TypeSystem::with_prelude().unwrap();
5990 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
5991 match err {
5992 TypeError::NonExhaustiveMatch { missing, .. } => {
5993 assert_eq!(missing, vec![sym("Cons")]);
5994 }
5995 other => panic!("expected non-exhaustive match, got {other:?}"),
5996 }
5997 }
5998
5999 #[test]
6000 fn infer_match_list_patterns_on_result_error() {
6001 let expr = parse_expr("match (Ok 1) when [] -> 0 when x::xs -> 1");
6002 let mut ts = TypeSystem::with_prelude().unwrap();
6003 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
6004 assert!(matches!(err, TypeError::Unification(_, _)));
6005 }
6006
6007 #[test]
6008 fn infer_missing_instances_produce_unsatisfied_predicates() {
6009 for (name, code) in [
6010 ("division", "1 / 2"),
6011 ("eq_dict", "{a = 1} == {a = 2}"),
6012 ("min_bool", "min [true]"),
6013 ("map_dict", r#"map (\x -> x) {a = 1}"#),
6014 ] {
6015 let (class, pred_type, expected_ty) = match name {
6016 "division" => (
6017 "Field",
6018 Type::builtin(BuiltinTypeId::I32),
6019 Some(Type::builtin(BuiltinTypeId::I32)),
6020 ),
6021 "eq_dict" => ("Eq", dict_of(Type::builtin(BuiltinTypeId::I32)), None),
6022 "min_bool" => ("Ord", Type::builtin(BuiltinTypeId::Bool), None),
6023 "map_dict" => ("Functor", Type::builtin(BuiltinTypeId::Dict), None),
6024 _ => unreachable!("unknown test case {name}"),
6025 };
6026
6027 let expr = parse_expr(code);
6028 let mut ts = TypeSystem::with_prelude().unwrap();
6029 let (preds, ty) = ts.infer(expr.as_ref()).unwrap();
6030 if let Some(expected) = expected_ty {
6031 assert_eq!(ty, expected, "{name}");
6032 }
6033
6034 let pred = preds
6035 .iter()
6036 .find(|p| p.class.as_ref() == class && p.typ == pred_type)
6037 .unwrap();
6038 assert!(!entails(&ts.classes, &[], pred).unwrap(), "{name}");
6039 }
6040 }
6041
6042 #[test]
6043 fn record_update_single_variant_adt_infers() {
6044 let program = parse_program(
6045 r#"
6046 type Foo = Bar { x: i32, y: i32 }
6047 let
6048 foo: Foo = Bar { x = 1, y = 2 },
6049 bar = { foo with { x = 3 } }
6050 in
6051 bar
6052 "#,
6053 );
6054 let mut ts = TypeSystem::with_prelude().unwrap();
6055 ts.inject_decls(&program.decls).unwrap();
6056 let (_preds, typ) = ts.infer(program.expr.as_ref()).unwrap();
6057 assert_eq!(typ.to_string(), "Foo");
6058 }
6059
6060 #[test]
6061 fn record_update_unknown_field_errors() {
6062 let program = parse_program(
6063 r#"
6064 type Foo = Bar { x: i32 }
6065 let
6066 foo: Foo = Bar { x = 1 }
6067 in
6068 { foo with { y = 2 } }
6069 "#,
6070 );
6071 let mut ts = TypeSystem::with_prelude().unwrap();
6072 ts.inject_decls(&program.decls).unwrap();
6073 let err = ts.infer(program.expr.as_ref()).unwrap_err();
6074 let err = strip_span(err);
6075 assert!(matches!(err, TypeError::UnknownField { .. }));
6076 }
6077
6078 #[test]
6079 fn record_update_requires_refined_variant_for_sum_types() {
6080 let program = parse_program(
6081 r#"
6082 type Foo = Bar { x: i32 } | Baz { x: i32 }
6083 let
6084 f = \ (foo : Foo) -> { foo with { x = 2 } }
6085 in
6086 f (Bar { x = 1 })
6087 "#,
6088 );
6089 let mut ts = TypeSystem::with_prelude().unwrap();
6090 ts.inject_decls(&program.decls).unwrap();
6091 let err = ts.infer(program.expr.as_ref()).unwrap_err();
6092 let err = strip_span(err);
6093 assert!(matches!(err, TypeError::FieldNotKnown { .. }));
6094 }
6095
6096 #[test]
6097 fn record_update_allowed_after_match_refines_variant() {
6098 let program = parse_program(
6099 r#"
6100 type Foo = Bar { x: i32 } | Baz { x: i32 }
6101 let
6102 f = \ (foo : Foo) ->
6103 match foo
6104 when Bar {x} -> { foo with { x = x + 1 } }
6105 when Baz {x} -> { foo with { x = x + 2 } }
6106 in
6107 f (Bar { x = 1 })
6108 "#,
6109 );
6110 let mut ts = TypeSystem::with_prelude().unwrap();
6111 ts.inject_decls(&program.decls).unwrap();
6112 let (_preds, typ) = ts.infer(program.expr.as_ref()).unwrap();
6113 assert_eq!(typ.to_string(), "Foo");
6114 }
6115
6116 #[test]
6117 fn record_update_plain_record_type() {
6118 let program = parse_program(
6119 r#"
6120 let
6121 f = \ (r : { x: i32, y: i32 }) -> { r with { y = 9 } }
6122 in
6123 f { x = 1, y = 2 }
6124 "#,
6125 );
6126 let mut ts = TypeSystem::with_prelude().unwrap();
6127 ts.inject_decls(&program.decls).unwrap();
6128 let (_preds, typ) = ts.infer(program.expr.as_ref()).unwrap();
6129 assert_eq!(typ.to_string(), "{x: i32, y: i32}");
6130 }
6131
6132 #[test]
6133 fn infer_typed_hole_expr_is_hole_kind() {
6134 let expr = parse_expr("?");
6135 let mut ts = TypeSystem::with_prelude().unwrap();
6136 let (typed, _preds, _ty) = ts.infer_typed(expr.as_ref()).unwrap();
6137 assert!(
6138 matches!(typed.kind, TypedExprKind::Hole),
6139 "typed={typed:#?}"
6140 );
6141 }
6142
6143 #[test]
6144 fn infer_hole_with_annotation_unifies_to_annotation() {
6145 let expr = parse_expr("let x : i32 = ? in x");
6146 let mut ts = TypeSystem::with_prelude().unwrap();
6147 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
6148 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
6149 }
6150
6151 #[test]
6152 fn infer_hole_in_if_condition_is_bool_constrained() {
6153 let expr = parse_expr("if ? then 1 else 2");
6154 let mut ts = TypeSystem::with_prelude().unwrap();
6155 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
6156 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
6157 }
6158
6159 #[test]
6160 fn infer_hole_in_arithmetic_is_numeric_constrained() {
6161 let expr = parse_expr("? + 1");
6162 let mut ts = TypeSystem::with_prelude().unwrap();
6163 let (_preds, ty) = ts.infer(expr.as_ref()).unwrap();
6164 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
6165 }
6166
6167 #[test]
6168 fn infer_hole_arithmetic_conflicting_annotation_failure() {
6169 let expr = parse_expr("let x : string = (? + 1) in x");
6170 let mut ts = TypeSystem::with_prelude().unwrap();
6171 let err = strip_span(ts.infer(expr.as_ref()).unwrap_err());
6172 assert!(matches!(err, TypeError::Unification(_, _)), "err={err:#?}");
6173 }
6174}