1use crate::{
2 error::TypeError,
3 types::{
4 AdtDecl, AdtVariant, BuiltinTypeId, Predicate, Scheme, Type, TypeConst, TypeEnv, TypeKind,
5 TypeVar, TypeVarId, TypedExpr, TypedExprKind, Types,
6 },
7 typesystem::{
8 TypeSystem, TypeVarSupply, instantiate, is_integral_literal_expr,
9 predicates_from_constraints, reject_ambiguous_scheme, type_from_annotation_expr,
10 type_from_annotation_expr_vars,
11 },
12 unification::{Subst, Unifier, compose_subst, subst_is_empty, unify},
13};
14use rex_ast::expr::{Expr, Pattern, Symbol, TypeConstraint, TypeExpr, sym};
15use rex_util::gas::GasMeter;
16use std::{
17 collections::{BTreeMap, BTreeSet},
18 sync::Arc,
19};
20
21fn dedup_preds(preds: Vec<Predicate>) -> Vec<Predicate> {
22 let mut seen = BTreeSet::new();
23 let mut out = Vec::with_capacity(preds.len());
24 for pred in preds {
25 if seen.insert(pred.clone()) {
26 out.push(pred);
27 }
28 }
29 out
30}
31
32fn is_integral_primitive(typ: &Type) -> bool {
33 matches!(
34 typ.as_ref(),
35 TypeKind::Con(TypeConst {
36 builtin_id: Some(
37 BuiltinTypeId::U8
38 | BuiltinTypeId::U16
39 | BuiltinTypeId::U32
40 | BuiltinTypeId::U64
41 | BuiltinTypeId::I8
42 | BuiltinTypeId::I16
43 | BuiltinTypeId::I32
44 | BuiltinTypeId::I64
45 ),
46 ..
47 })
48 )
49}
50
51fn finalize_infer_for_public_api(
52 mut preds: Vec<Predicate>,
53 mut typ: Type,
54) -> Result<(Vec<Predicate>, Type), TypeError> {
55 let mut subst = Subst::new_sync();
56 for pred in &preds {
57 if pred.class.as_ref() == "Integral"
58 && let TypeKind::Var(tv) = pred.typ.as_ref()
59 {
60 subst = subst.insert(tv.id, Type::builtin(BuiltinTypeId::I32));
61 }
62 }
63
64 if !subst_is_empty(&subst) {
65 preds = dedup_preds(preds.apply(&subst));
66 typ = typ.apply(&subst);
67 }
68
69 for pred in &preds {
70 if pred.class.as_ref() != "Integral" {
71 continue;
72 }
73 if matches!(pred.typ.as_ref(), TypeKind::Var(_)) || is_integral_primitive(&pred.typ) {
74 continue;
75 }
76 return Err(TypeError::Unification("i32".into(), pred.typ.to_string()));
77 }
78
79 Ok((preds, typ))
80}
81
82#[derive(Clone, Debug)]
83struct KnownVariant {
84 adt: Symbol,
85 variant: Symbol,
86}
87
88type KnownVariants = BTreeMap<Symbol, KnownVariant>;
89
90fn apply_scheme_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> Scheme {
91 let preds = scheme
92 .preds
93 .iter()
94 .map(|pred| Predicate::new(pred.class.clone(), unifier.apply_type(&pred.typ)))
95 .collect();
96 let typ = unifier.apply_type(&scheme.typ);
97 Scheme::new(scheme.vars.clone(), preds, typ)
98}
99
100fn scheme_ftv_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> BTreeSet<TypeVarId> {
101 let mut ftv = unifier.apply_type(&scheme.typ).ftv();
102 for pred in &scheme.preds {
103 ftv.extend(unifier.apply_type(&pred.typ).ftv());
104 }
105 for var in &scheme.vars {
106 ftv.remove(&var.id);
107 }
108 ftv
109}
110
111fn env_ftv_with_unifier(env: &TypeEnv, unifier: &mut Unifier<'_>) -> BTreeSet<TypeVarId> {
112 let mut out = BTreeSet::new();
113 for (_name, schemes) in env.values.iter() {
114 for scheme in schemes {
115 out.extend(scheme_ftv_with_unifier(scheme, unifier));
116 }
117 }
118 out
119}
120
121fn generalize_with_unifier(
122 env: &TypeEnv,
123 preds: Vec<Predicate>,
124 typ: Type,
125 unifier: &mut Unifier<'_>,
126) -> Scheme {
127 let preds: Vec<Predicate> = preds
128 .into_iter()
129 .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
130 .collect();
131 let typ = unifier.apply_type(&typ);
132 let mut vars: Vec<TypeVar> = typ
133 .ftv()
134 .union(&preds.ftv())
135 .copied()
136 .collect::<BTreeSet<_>>()
137 .difference(&env_ftv_with_unifier(env, unifier))
138 .cloned()
139 .map(|id| TypeVar::new(id, None))
140 .collect();
141 vars.sort_by_key(|v| v.id);
142 Scheme::new(vars, preds, typ)
143}
144
145fn monomorphic_scheme_with_unifier(
146 preds: Vec<Predicate>,
147 typ: Type,
148 unifier: &mut Unifier<'_>,
149) -> Scheme {
150 let preds = dedup_preds(
151 preds
152 .into_iter()
153 .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
154 .collect(),
155 );
156 let typ = unifier.apply_type(&typ);
157 Scheme::new(vec![], preds, typ)
158}
159
160pub fn infer_typed(
161 type_system: &mut TypeSystem,
162 expr: &Expr,
163) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
164 infer_typed_inner(type_system, expr)
165}
166
167pub fn infer_typed_with_gas(
168 type_system: &mut TypeSystem,
169 expr: &Expr,
170 gas: &mut GasMeter,
171) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
172 let known = KnownVariants::new();
173 let mut unifier = Unifier::with_gas(gas, type_system.limits.max_infer_depth);
174 let (preds, t, typed) = infer_expr(
175 &mut unifier,
176 &mut type_system.supply,
177 &type_system.env,
178 &type_system.adts,
179 &known,
180 expr,
181 )
182 .map_err(|err| err.with_span(expr.span()))?;
183 let subst = unifier.into_subst();
184 let mut typed = typed.apply(&subst);
185 let mut preds = dedup_preds(preds.apply(&subst));
186 let mut t = t.apply(&subst);
187 let improve = improve_indexable(&preds)?;
188 if !subst_is_empty(&improve) {
189 typed = typed.apply(&improve);
190 preds = dedup_preds(preds.apply(&improve));
191 t = t.apply(&improve);
192 }
193 type_system.check_predicate_kinds(&preds)?;
194 Ok((typed, preds, t))
195}
196
197fn infer_typed_inner(
198 type_system: &mut TypeSystem,
199 expr: &Expr,
200) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
201 let known = KnownVariants::new();
202 let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
203 let (preds, t, typed) = infer_expr(
204 &mut unifier,
205 &mut type_system.supply,
206 &type_system.env,
207 &type_system.adts,
208 &known,
209 expr,
210 )
211 .map_err(|err| err.with_span(expr.span()))?;
212 let subst = unifier.into_subst();
213 let mut typed = typed.apply(&subst);
214 let mut preds = dedup_preds(preds.apply(&subst));
215 let mut t = t.apply(&subst);
216 let improve = improve_indexable(&preds)?;
217 if !subst_is_empty(&improve) {
218 typed = typed.apply(&improve);
219 preds = dedup_preds(preds.apply(&improve));
220 t = t.apply(&improve);
221 }
222 type_system.check_predicate_kinds(&preds)?;
223 Ok((typed, preds, t))
224}
225
226pub fn infer(
227 type_system: &mut TypeSystem,
228 expr: &Expr,
229) -> Result<(Vec<Predicate>, Type), TypeError> {
230 infer_inner(type_system, expr)
231}
232
233pub fn infer_with_gas(
234 type_system: &mut TypeSystem,
235 expr: &Expr,
236 gas: &mut GasMeter,
237) -> Result<(Vec<Predicate>, Type), TypeError> {
238 let known = KnownVariants::new();
239 let mut unifier = Unifier::with_gas(gas, type_system.limits.max_infer_depth);
240 let (preds, t) = infer_expr_type(
241 &mut unifier,
242 &mut type_system.supply,
243 &type_system.env,
244 &type_system.adts,
245 &known,
246 expr,
247 )
248 .map_err(|err| err.with_span(expr.span()))?;
249 let subst = unifier.into_subst();
250 let preds = dedup_preds(preds.apply(&subst));
251 let t = t.apply(&subst);
252 type_system.check_predicate_kinds(&preds)?;
253 finalize_infer_for_public_api(preds, t)
254}
255
256fn infer_inner(
257 type_system: &mut TypeSystem,
258 expr: &Expr,
259) -> Result<(Vec<Predicate>, Type), TypeError> {
260 let known = KnownVariants::new();
261 let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
262 let (preds, t) = infer_expr_type(
263 &mut unifier,
264 &mut type_system.supply,
265 &type_system.env,
266 &type_system.adts,
267 &known,
268 expr,
269 )
270 .map_err(|err| err.with_span(expr.span()))?;
271 let subst = unifier.into_subst();
272 let mut preds = dedup_preds(preds.apply(&subst));
273 let mut t = t.apply(&subst);
274 let improve = improve_indexable(&preds)?;
275 if !subst_is_empty(&improve) {
276 preds = dedup_preds(preds.apply(&improve));
277 t = t.apply(&improve);
278 }
279 type_system.check_predicate_kinds(&preds)?;
280 finalize_infer_for_public_api(preds, t)
281}
282
283fn improve_indexable(preds: &[Predicate]) -> Result<Subst, TypeError> {
284 let mut subst = Subst::new_sync();
285 loop {
286 let mut changed = false;
287 for pred in preds {
288 let pred = pred.apply(&subst);
289 if pred.class.as_ref() != "Indexable" {
290 continue;
291 }
292 let TypeKind::Tuple(parts) = pred.typ.as_ref() else {
293 continue;
294 };
295 if parts.len() != 2 {
296 continue;
297 }
298 let container = parts[0].clone();
299 let elem = parts[1].clone();
300 let s = indexable_elem_subst(&container, &elem)?;
301 if !subst_is_empty(&s) {
302 subst = compose_subst(s, subst);
303 changed = true;
304 }
305 }
306 if !changed {
307 break;
308 }
309 }
310 Ok(subst)
311}
312
313fn indexable_elem_subst(container: &Type, elem: &Type) -> Result<Subst, TypeError> {
314 match container.as_ref() {
315 TypeKind::App(head, arg) => match head.as_ref() {
316 TypeKind::Con(tc)
317 if matches!(
318 tc.builtin_id,
319 Some(BuiltinTypeId::List | BuiltinTypeId::Array)
320 ) =>
321 {
322 unify(elem, arg)
323 }
324 _ => Ok(Subst::new_sync()),
325 },
326 TypeKind::Tuple(elems) => {
327 if elems.is_empty() {
328 return Ok(Subst::new_sync());
329 }
330 let mut subst = Subst::new_sync();
331 let mut cur = elems[0].clone();
332 for ty in elems.iter().skip(1) {
333 let s_next = unify(&cur.apply(&subst), &ty.apply(&subst))?;
334 subst = compose_subst(s_next, subst);
335 cur = cur.apply(&subst);
336 }
337 let elem = elem.apply(&subst);
338 let s_elem = unify(&elem, &cur.apply(&subst))?;
339 Ok(compose_subst(s_elem, subst))
340 }
341 _ => Ok(Subst::new_sync()),
342 }
343}
344
345type LambdaChain<'a> = (
346 Vec<(Symbol, Option<TypeExpr>)>,
347 Vec<TypeConstraint>,
348 &'a Expr,
349);
350
351fn collect_lambda_chain<'a>(expr: &'a Expr) -> LambdaChain<'a> {
352 let mut params = Vec::new();
353 let mut constraints = Vec::new();
354 let mut cur = expr;
355 let mut seen_constraints = false;
356 while let Expr::Lam(_, _scope, param, ann, lam_constraints, body) = cur {
357 if !lam_constraints.is_empty() {
358 if seen_constraints {
359 break;
360 }
361 constraints = lam_constraints.clone();
362 seen_constraints = true;
363 }
364 params.push((param.name.clone(), ann.clone()));
365 cur = body.as_ref();
366 }
367 (params, constraints, cur)
368}
369
370fn collect_app_chain(expr: &Expr) -> (&Expr, Vec<&Expr>) {
371 let mut args = Vec::new();
372 let mut cur = expr;
373 while let Expr::App(_, f, x) = cur {
374 args.push(x.as_ref());
375 cur = f.as_ref();
376 }
377 args.reverse();
378 (cur, args)
379}
380
381fn narrow_overload_candidates(candidates: &[Type], arg_ty: &Type) -> Vec<Type> {
382 let mut out = Vec::new();
383 for candidate in candidates {
384 let Some((params, ret)) = decompose_fun(candidate, 1) else {
385 continue;
386 };
387 let param = ¶ms[0];
388 if let Ok(s) = unify(param, arg_ty) {
389 out.push(ret.apply(&s));
390 }
391 }
392 out
393}
394
395fn unary_app_arg(typ: &Type, ctor_name: &str) -> Option<Type> {
396 let TypeKind::App(head, arg) = typ.as_ref() else {
397 return None;
398 };
399 let TypeKind::Con(tc) = head.as_ref() else {
400 return None;
401 };
402 (tc.name.as_ref() == ctor_name && tc.arity == 1).then(|| arg.clone())
403}
404
405fn infer_app_arg_type(
406 unifier: &mut Unifier<'_>,
407 supply: &mut TypeVarSupply,
408 env: &TypeEnv,
409 adts: &BTreeMap<Symbol, AdtDecl>,
410 known: &KnownVariants,
411 arg_hint: Option<Type>,
412 arg: &Expr,
413) -> Result<(Vec<Predicate>, Type), TypeError> {
414 match (arg_hint, arg) {
415 (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
416 infer_record_update_type_with_hint(
417 unifier,
418 supply,
419 env,
420 adts,
421 known,
422 base.as_ref(),
423 updates,
424 &arg_hint,
425 )
426 }
427 (Some(arg_hint), Expr::Dict(_, kvs))
428 if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
429 {
430 let TypeKind::Record(fields) = arg_hint.as_ref() else {
431 unreachable!("guarded by matches!")
432 };
433 let expected: BTreeMap<_, _> =
434 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
435 let mut seen = BTreeSet::new();
436 let mut preds = Vec::new();
437 for (k, v) in kvs {
438 let expected_ty = expected
439 .get(k)
440 .ok_or_else(|| TypeError::UnknownField {
441 field: k.clone(),
442 typ: Type::record(fields.clone()).to_string(),
443 })?
444 .clone();
445 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
446 unifier.unify(&t1, &expected_ty)?;
447 preds.extend(p1);
448 seen.insert(k.clone());
449 }
450 for key in expected.keys() {
451 if !seen.contains(key.as_ref()) {
452 return Err(TypeError::UnknownField {
453 field: key.clone(),
454 typ: Type::record(fields.clone()).to_string(),
455 });
456 }
457 }
458 let record_ty = Type::record(
459 fields
460 .iter()
461 .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
462 .collect(),
463 );
464 Ok((preds, record_ty))
465 }
466 _ => infer_expr_type(unifier, supply, env, adts, known, arg),
467 }
468}
469
470fn infer_app_arg_typed(
471 unifier: &mut Unifier<'_>,
472 supply: &mut TypeVarSupply,
473 env: &TypeEnv,
474 adts: &BTreeMap<Symbol, AdtDecl>,
475 known: &KnownVariants,
476 arg_hint: Option<Type>,
477 arg: &Expr,
478) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
479 match (arg_hint, arg) {
480 (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
481 infer_record_update_typed_with_hint(
482 unifier,
483 supply,
484 env,
485 adts,
486 known,
487 base.as_ref(),
488 updates,
489 &arg_hint,
490 )
491 }
492 (Some(arg_hint), Expr::Dict(_, kvs))
493 if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
494 {
495 let TypeKind::Record(fields) = arg_hint.as_ref() else {
496 unreachable!("guarded by matches!")
497 };
498 let mut preds = Vec::new();
499 let mut typed_kvs = BTreeMap::new();
500 let expected: BTreeMap<_, _> =
501 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
502 for (k, v) in kvs {
503 let expected_ty = expected
504 .get(k)
505 .ok_or_else(|| TypeError::UnknownField {
506 field: k.clone(),
507 typ: Type::record(fields.clone()).to_string(),
508 })?
509 .clone();
510 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
511 unifier.unify(&t1, &expected_ty)?;
512 preds.extend(p1);
513 typed_kvs.insert(k.clone(), typed_v);
514 }
515 for key in expected.keys() {
516 if !typed_kvs.contains_key(key.as_ref()) {
517 return Err(TypeError::UnknownField {
518 field: key.clone(),
519 typ: Type::record(fields.clone()).to_string(),
520 });
521 }
522 }
523 let record_ty = Type::record(
524 fields
525 .iter()
526 .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
527 .collect(),
528 );
529 let typed = TypedExpr::new(record_ty.clone(), TypedExprKind::Dict(typed_kvs));
530 Ok((preds, record_ty, typed))
531 }
532 _ => infer_expr(unifier, supply, env, adts, known, arg),
533 }
534}
535
536#[allow(clippy::too_many_arguments)]
537fn infer_record_update_type_with_hint(
538 unifier: &mut Unifier<'_>,
539 supply: &mut TypeVarSupply,
540 env: &TypeEnv,
541 adts: &BTreeMap<Symbol, AdtDecl>,
542 known: &KnownVariants,
543 base: &Expr,
544 updates: &BTreeMap<Symbol, Arc<Expr>>,
545 hint_ty: &Type,
546) -> Result<(Vec<Predicate>, Type), TypeError> {
547 let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
548 unifier.unify(&t_base, hint_ty)?;
549 let base_ty = unifier.apply_type(&t_base);
550 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
551 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
552 let (result_ty, fields) = resolve_record_update(
553 unifier,
554 supply,
555 adts,
556 &base_ty,
557 known_variant,
558 &update_fields,
559 )?;
560 let expected: BTreeMap<_, _> = fields.into_iter().collect();
561
562 let mut preds = p_base;
563 for (k, v) in updates {
564 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
565 field: k.clone(),
566 typ: result_ty.to_string(),
567 })?;
568 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
569 unifier.unify(&t1, expected_ty)?;
570 preds.extend(p1);
571 }
572 Ok((preds, result_ty))
573}
574
575#[allow(clippy::too_many_arguments)]
576fn infer_record_update_typed_with_hint(
577 unifier: &mut Unifier<'_>,
578 supply: &mut TypeVarSupply,
579 env: &TypeEnv,
580 adts: &BTreeMap<Symbol, AdtDecl>,
581 known: &KnownVariants,
582 base: &Expr,
583 updates: &BTreeMap<Symbol, Arc<Expr>>,
584 hint_ty: &Type,
585) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
586 let (p_base, t_base, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
587 unifier.unify(&t_base, hint_ty)?;
588 let base_ty = unifier.apply_type(&t_base);
589 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
590 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
591 let (result_ty, fields) = resolve_record_update(
592 unifier,
593 supply,
594 adts,
595 &base_ty,
596 known_variant,
597 &update_fields,
598 )?;
599 let expected: BTreeMap<_, _> = fields.into_iter().collect();
600
601 let mut preds = p_base;
602 let mut typed_updates = BTreeMap::new();
603 for (k, v) in updates {
604 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
605 field: k.clone(),
606 typ: result_ty.to_string(),
607 })?;
608 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
609 unifier.unify(&t1, expected_ty)?;
610 preds.extend(p1);
611 typed_updates.insert(k.clone(), typed_v);
612 }
613
614 let typed = TypedExpr::new(
615 result_ty.clone(),
616 TypedExprKind::RecordUpdate {
617 base: Box::new(typed_base),
618 updates: typed_updates,
619 },
620 );
621 Ok((preds, result_ty, typed))
622}
623
624fn infer_expr_type(
625 unifier: &mut Unifier<'_>,
626 supply: &mut TypeVarSupply,
627 env: &TypeEnv,
628 adts: &BTreeMap<Symbol, AdtDecl>,
629 known: &KnownVariants,
630 expr: &Expr,
631) -> Result<(Vec<Predicate>, Type), TypeError> {
632 let span = *expr.span();
633 let res = unifier.with_infer_depth(span, |unifier| {
634 infer_expr_type_inner(unifier, supply, env, adts, known, expr)
635 });
636 res.map_err(|err| err.with_span(&span))
637}
638
639fn infer_expr_type_inner(
640 unifier: &mut Unifier<'_>,
641 supply: &mut TypeVarSupply,
642 env: &TypeEnv,
643 adts: &BTreeMap<Symbol, AdtDecl>,
644 known: &KnownVariants,
645 expr: &Expr,
646) -> Result<(Vec<Predicate>, Type), TypeError> {
647 unifier.charge_infer_node()?;
648 match expr {
649 Expr::Bool(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Bool))),
650 Expr::Uint(_, _) => {
651 let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
652 Ok((vec![Predicate::new("Integral", lit_ty.clone())], lit_ty))
653 }
654 Expr::Int(_, _) => {
655 let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
656 Ok((
657 vec![
658 Predicate::new("Integral", lit_ty.clone()),
659 Predicate::new("AdditiveGroup", lit_ty.clone()),
660 ],
661 lit_ty,
662 ))
663 }
664 Expr::Float(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::F32))),
665 Expr::String(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::String))),
666 Expr::Uuid(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Uuid))),
667 Expr::DateTime(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::DateTime))),
668 Expr::Hole(_) => {
669 let t = Type::var(supply.fresh(Some(sym("hole"))));
670 Ok((vec![], t))
671 }
672 Expr::Var(var) => {
673 let schemes = env
674 .lookup(&var.name)
675 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
676 if schemes.len() == 1 {
677 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
678 let (preds, t) = instantiate(&scheme, supply);
679 Ok((preds, t))
680 } else {
681 for scheme in schemes {
682 if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
683 return Err(TypeError::AmbiguousOverload(var.name.clone()));
684 }
685 }
686 let t = Type::var(supply.fresh(Some(var.name.clone())));
687 Ok((vec![], t))
688 }
689 }
690 Expr::Lam(..) => {
691 let (params, constraints, body) = collect_lambda_chain(expr);
692 let mut ann_vars = BTreeMap::new();
693 let mut param_tys = Vec::with_capacity(params.len());
694 for (name, ann) in ¶ms {
695 let param_ty = match ann {
696 Some(ann) => type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?,
697 None => Type::var(supply.fresh(Some(name.clone()))),
698 };
699 param_tys.push((name.clone(), param_ty));
700 }
701
702 let mut env1 = env.clone();
703 let mut known_body = known.clone();
704 for (name, param_ty) in ¶m_tys {
705 env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
706 known_body.remove(name);
707 }
708
709 let (mut preds, body_ty) =
710 infer_expr_type(unifier, supply, &env1, adts, &known_body, body)?;
711 let constraint_preds =
712 predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
713 preds.extend(constraint_preds);
714
715 let mut fun_ty = unifier.apply_type(&body_ty);
716 for (_, param_ty) in param_tys.iter().rev() {
717 fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
718 }
719 Ok((preds, fun_ty))
720 }
721 Expr::App(..) => {
722 let (head, args) = collect_app_chain(expr);
723 let (mut preds, mut func_ty) =
724 infer_expr_type(unifier, supply, env, adts, known, head)?;
725 let mut overload_name = None;
726 let mut overload_candidates = if let Expr::Var(var) = head {
727 if let Some(schemes) = env.lookup(&var.name) {
728 if schemes.len() <= 1 {
729 None
730 } else {
731 let mut candidates = Vec::new();
732 for scheme in schemes {
733 if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
734 return Err(TypeError::AmbiguousOverload(var.name.clone()));
735 }
736 let scheme = apply_scheme_with_unifier(scheme, unifier);
737 let (p, typ) = instantiate(&scheme, supply);
738 if !p.is_empty() {
739 return Err(TypeError::AmbiguousOverload(var.name.clone()));
740 }
741 candidates.push(typ);
742 }
743 overload_name = Some(var.name.clone());
744 Some(candidates)
745 }
746 } else {
747 None
748 }
749 } else {
750 None
751 };
752 for arg in args {
753 let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
754 TypeKind::Fun(arg, _) => Some(arg.clone()),
755 _ => None,
756 };
757 let (p_arg, arg_ty) =
758 infer_app_arg_type(unifier, supply, env, adts, known, arg_hint, arg)?;
759 let arg_ty = unifier.apply_type(&arg_ty);
760 if let Some(candidates) = overload_candidates.take() {
761 let candidates = candidates
762 .into_iter()
763 .map(|t| unifier.apply_type(&t))
764 .collect::<Vec<_>>();
765 let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
766 if narrowed.is_empty()
767 && let Some(name) = &overload_name
768 {
769 return Err(TypeError::AmbiguousOverload(name.clone()));
770 }
771 overload_candidates = Some(narrowed);
772 }
773 let res_ty = match overload_candidates.as_ref() {
774 Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
775 _ => Type::var(supply.fresh(Some("r".into()))),
776 };
777 unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
778 preds.extend(p_arg);
779 func_ty = match overload_candidates.as_ref() {
780 Some(candidates) if candidates.len() == 1 => unifier.apply_type(&candidates[0]),
781 _ => unifier.apply_type(&res_ty),
782 };
783 }
784 Ok((preds, func_ty))
785 }
786 Expr::Project(_, base, field) => {
787 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, base)?;
788 let base_ty = unifier.apply_type(&t1);
789 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
790 let field_ty =
791 resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
792 Ok((p1, field_ty))
793 }
794 Expr::RecordUpdate(_, base, updates) => {
795 let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
796 let base_ty = unifier.apply_type(&t_base);
797 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
798 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
799 let (result_ty, fields) = resolve_record_update(
800 unifier,
801 supply,
802 adts,
803 &base_ty,
804 known_variant,
805 &update_fields,
806 )?;
807 let expected: BTreeMap<_, _> = fields.into_iter().collect();
808
809 let mut preds = p_base;
810 for (k, v) in updates {
811 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
812 field: k.clone(),
813 typ: result_ty.to_string(),
814 })?;
815 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
816 unifier.unify(&t1, expected_ty)?;
817 preds.extend(p1);
818 }
819 Ok((preds, result_ty))
820 }
821 Expr::Let(..) => {
822 let mut bindings = Vec::new();
823 let mut cur = expr;
824 while let Expr::Let(_, v, ann, d, b) = cur {
825 bindings.push((v.clone(), ann.clone(), d.clone()));
826 cur = b.as_ref();
827 }
828
829 let mut env_cur = env.clone();
830 let mut known_cur = known.clone();
831 for (v, ann, d) in bindings {
832 let (p1, t1) = if let Some(ref ann_expr) = ann {
833 let mut ann_vars = BTreeMap::new();
834 let ann_ty =
835 type_from_annotation_expr_vars(adts, ann_expr, &mut ann_vars, supply)?;
836 match d.as_ref() {
837 Expr::RecordUpdate(_, base, updates) => infer_record_update_type_with_hint(
838 unifier,
839 supply,
840 &env_cur,
841 adts,
842 &known_cur,
843 base.as_ref(),
844 updates,
845 &ann_ty,
846 )?,
847 _ => {
848 let (p1, t1) =
849 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?;
850 unifier.unify(&t1, &ann_ty)?;
851 (p1, t1)
852 }
853 }
854 } else {
855 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?
856 };
857 let def_ty = unifier.apply_type(&t1);
858 let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
859 monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
860 } else {
861 let scheme = generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
862 reject_ambiguous_scheme(&scheme)?;
863 scheme
864 };
865 env_cur.extend(v.name.clone(), scheme);
866 if let Some(known_variant) =
867 known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
868 {
869 known_cur.insert(
870 v.name.clone(),
871 KnownVariant {
872 adt: known_variant.adt,
873 variant: known_variant.variant,
874 },
875 );
876 } else {
877 known_cur.remove(&v.name);
878 }
879 }
880
881 let (p_body, t_body) =
882 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, cur)?;
883 Ok((p_body, t_body))
884 }
885 Expr::LetRec(_, bindings, body) => {
886 let mut env_seed = env.clone();
887 let mut known_seed = known.clone();
888 let mut binding_tys = BTreeMap::new();
889 for (var, _ann, _def) in bindings {
890 let tv = Type::var(supply.fresh(Some(var.name.clone())));
891 env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
892 known_seed.remove(&var.name);
893 binding_tys.insert(var.name.clone(), tv);
894 }
895
896 let mut inferred = Vec::with_capacity(bindings.len());
897 for (var, ann, def) in bindings {
898 let (preds, def_ty) =
899 infer_expr_type(unifier, supply, &env_seed, adts, &known_seed, def)?;
900 if let Some(ann) = ann {
901 let mut ann_vars = BTreeMap::new();
902 let ann_ty = type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
903 unifier.unify(&def_ty, &ann_ty)?;
904 }
905 let binding_ty = binding_tys
906 .get(&var.name)
907 .cloned()
908 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
909 unifier.unify(&binding_ty, &def_ty)?;
910 let resolved_ty = unifier.apply_type(&binding_ty);
911
912 if let Some(known_variant) =
913 known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
914 {
915 known_seed.insert(
916 var.name.clone(),
917 KnownVariant {
918 adt: known_variant.adt,
919 variant: known_variant.variant,
920 },
921 );
922 } else {
923 known_seed.remove(&var.name);
924 }
925 inferred.push((var.name.clone(), preds, resolved_ty));
926 }
927
928 let mut env_body = env.clone();
929 for (name, preds, def_ty) in inferred {
930 let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
931 reject_ambiguous_scheme(&scheme)?;
932 env_body.extend(name, scheme);
933 }
934
935 let (p_body, t_body) =
936 infer_expr_type(unifier, supply, &env_body, adts, &known_seed, body)?;
937 Ok((p_body, t_body))
938 }
939 Expr::Ite(_, cond, then_expr, else_expr) => {
940 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, cond)?;
941 unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
942 let (p2, t2) = infer_expr_type(unifier, supply, env, adts, known, then_expr)?;
943 let (p3, t3) = infer_expr_type(unifier, supply, env, adts, known, else_expr)?;
944 unifier.unify(&t2, &t3)?;
945 let out_ty = unifier.apply_type(&t2);
946 let mut preds = p1;
947 preds.extend(p2);
948 preds.extend(p3);
949 Ok((preds, out_ty))
950 }
951 Expr::Tuple(_, elems) => {
952 let mut preds = Vec::new();
953 let mut types = Vec::new();
954 for elem in elems {
955 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
956 preds.extend(p1);
957 types.push(unifier.apply_type(&t1));
958 }
959 let tuple_ty = Type::tuple(types);
960 Ok((preds, tuple_ty))
961 }
962 Expr::List(_, elems) => {
963 let elem_tv = Type::var(supply.fresh(Some("a".into())));
964 let mut preds = Vec::new();
965 for elem in elems {
966 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
967 unifier.unify(&t1, &elem_tv)?;
968 preds.extend(p1);
969 }
970 let list_ty = Type::app(
971 Type::builtin(BuiltinTypeId::List),
972 unifier.apply_type(&elem_tv),
973 );
974 Ok((preds, list_ty))
975 }
976 Expr::Dict(_, kvs) => {
977 let elem_tv = Type::var(supply.fresh(Some("v".into())));
978 let mut preds = Vec::new();
979 for v in kvs.values() {
980 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
981 unifier.unify(&t1, &elem_tv)?;
982 preds.extend(p1);
983 }
984 let dict_ty = Type::app(
985 Type::builtin(BuiltinTypeId::Dict),
986 unifier.apply_type(&elem_tv),
987 );
988 Ok((preds, dict_ty))
989 }
990 Expr::Match(_, scrutinee, arms) => {
991 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, scrutinee.as_ref())?;
992 let mut preds = p1;
993 let res_ty = Type::var(supply.fresh(Some("match".into())));
994 let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
995
996 for (pat, expr) in arms {
997 let scrutinee_ty = unifier.apply_type(&t1);
998 let (p_pat, binds) = infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
999 preds.extend(p_pat);
1000
1001 let mut env_arm = env.clone();
1002 for (name, ty) in binds {
1003 env_arm.extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
1004 }
1005 let mut known_arm = known.clone();
1006 if let Expr::Var(var) = scrutinee.as_ref() {
1007 match pat {
1008 Pattern::Named(_, name, _) => {
1009 let name_sym = name.to_dotted_symbol();
1010 if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
1011 known_arm.insert(
1012 var.name.clone(),
1013 KnownVariant {
1014 adt: adt.name.clone(),
1015 variant: name_sym,
1016 },
1017 );
1018 } else {
1019 known_arm.remove(&var.name);
1020 }
1021 }
1022 _ => {
1023 known_arm.remove(&var.name);
1024 }
1025 }
1026 }
1027 let (p_expr, t_expr) =
1028 infer_expr_type(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1029 unifier.unify(&res_ty, &t_expr)?;
1030 preds.extend(p_expr);
1031 }
1032
1033 let scrutinee_ty = unifier.apply_type(&t1);
1034 check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1035 let out_ty = unifier.apply_type(&res_ty);
1036 Ok((preds, out_ty))
1037 }
1038 Expr::Ann(_, expr, ann) => {
1039 let ann_ty = type_from_annotation_expr(adts, ann)?;
1040 match expr.as_ref() {
1041 Expr::RecordUpdate(_, base, updates) => {
1042 let (preds, out_ty) = infer_record_update_type_with_hint(
1043 unifier,
1044 supply,
1045 env,
1046 adts,
1047 known,
1048 base.as_ref(),
1049 updates,
1050 &ann_ty,
1051 )?;
1052 Ok((preds, out_ty))
1053 }
1054 _ => {
1055 let (preds, expr_ty) =
1056 infer_expr_type(unifier, supply, env, adts, known, expr)?;
1057 unifier.unify(&expr_ty, &ann_ty)?;
1058 let out_ty = unifier.apply_type(&ann_ty);
1059 Ok((preds, out_ty))
1060 }
1061 }
1062 }
1063 }
1064}
1065
1066fn infer_expr(
1067 unifier: &mut Unifier<'_>,
1068 supply: &mut TypeVarSupply,
1069 env: &TypeEnv,
1070 adts: &BTreeMap<Symbol, AdtDecl>,
1071 known: &KnownVariants,
1072 expr: &Expr,
1073) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
1074 let span = *expr.span();
1075 let res = unifier.with_infer_depth(span, |unifier| {
1076 (|| {
1077 unifier.charge_infer_node()?;
1078 match expr {
1079 Expr::Bool(_, v) => {
1080 let t = Type::builtin(BuiltinTypeId::Bool);
1081 Ok((
1082 vec![],
1083 t.clone(),
1084 TypedExpr::new(t, TypedExprKind::Bool(*v)),
1085 ))
1086 }
1087 Expr::Uint(_, v) => {
1088 let t = Type::var(supply.fresh(Some(sym("n"))));
1089 Ok((
1090 vec![Predicate::new("Integral", t.clone())],
1091 t.clone(),
1092 TypedExpr::new(t, TypedExprKind::Uint(*v)),
1093 ))
1094 }
1095 Expr::Int(_, v) => {
1096 let t = Type::var(supply.fresh(Some(sym("n"))));
1097 Ok((
1098 vec![
1099 Predicate::new("Integral", t.clone()),
1100 Predicate::new("AdditiveGroup", t.clone()),
1101 ],
1102 t.clone(),
1103 TypedExpr::new(t, TypedExprKind::Int(*v)),
1104 ))
1105 }
1106 Expr::Float(_, v) => {
1107 let t = Type::builtin(BuiltinTypeId::F32);
1108 Ok((
1109 vec![],
1110 t.clone(),
1111 TypedExpr::new(t, TypedExprKind::Float(*v)),
1112 ))
1113 }
1114 Expr::String(_, v) => {
1115 let t = Type::builtin(BuiltinTypeId::String);
1116 Ok((
1117 vec![],
1118 t.clone(),
1119 TypedExpr::new(t, TypedExprKind::String(v.clone())),
1120 ))
1121 }
1122 Expr::Uuid(_, v) => {
1123 let t = Type::builtin(BuiltinTypeId::Uuid);
1124 Ok((
1125 vec![],
1126 t.clone(),
1127 TypedExpr::new(t, TypedExprKind::Uuid(*v)),
1128 ))
1129 }
1130 Expr::DateTime(_, v) => {
1131 let t = Type::builtin(BuiltinTypeId::DateTime);
1132 Ok((
1133 vec![],
1134 t.clone(),
1135 TypedExpr::new(t, TypedExprKind::DateTime(*v)),
1136 ))
1137 }
1138 Expr::Hole(_) => {
1139 let t = Type::var(supply.fresh(Some(sym("hole"))));
1140 Ok((vec![], t.clone(), TypedExpr::new(t, TypedExprKind::Hole)))
1141 }
1142 Expr::Var(var) => {
1143 let schemes = env
1144 .lookup(&var.name)
1145 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1146 if schemes.len() == 1 {
1147 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
1148 let (preds, t) = instantiate(&scheme, supply);
1149 let typed = TypedExpr::new(
1150 t.clone(),
1151 TypedExprKind::Var {
1152 name: var.name.clone(),
1153 overloads: vec![],
1154 },
1155 );
1156 Ok((preds, t, typed))
1157 } else {
1158 let mut overloads = Vec::new();
1159 for scheme in schemes {
1160 if !scheme.preds.is_empty() {
1161 return Err(TypeError::AmbiguousOverload(var.name.clone()));
1162 }
1163
1164 let scheme = apply_scheme_with_unifier(scheme, unifier);
1165 let (preds, typ) = instantiate(&scheme, supply);
1166 if !preds.is_empty() {
1167 return Err(TypeError::AmbiguousOverload(var.name.clone()));
1168 }
1169 overloads.push(typ);
1170 }
1171 let t = Type::var(supply.fresh(Some(var.name.clone())));
1172 let typed = TypedExpr::new(
1173 t.clone(),
1174 TypedExprKind::Var {
1175 name: var.name.clone(),
1176 overloads,
1177 },
1178 );
1179 Ok((vec![], t, typed))
1180 }
1181 }
1182 Expr::Lam(..) => {
1183 let (params, constraints, body) = collect_lambda_chain(expr);
1184 let mut ann_vars = BTreeMap::new();
1185 let mut param_tys = Vec::with_capacity(params.len());
1186 for (name, ann) in ¶ms {
1187 let param_ty = match ann {
1188 Some(ann) => {
1189 type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?
1190 }
1191 None => Type::var(supply.fresh(Some(name.clone()))),
1192 };
1193 param_tys.push((name.clone(), param_ty));
1194 }
1195
1196 let mut env1 = env.clone();
1197 let mut known_body = known.clone();
1198 for (name, param_ty) in ¶m_tys {
1199 env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
1200 known_body.remove(name);
1201 }
1202
1203 let (mut preds, body_ty, typed_body) =
1204 infer_expr(unifier, supply, &env1, adts, &known_body, body)?;
1205 let constraint_preds =
1206 predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
1207 preds.extend(constraint_preds);
1208
1209 let mut typed = typed_body;
1210 let mut fun_ty = unifier.apply_type(&body_ty);
1211 for (name, param_ty) in param_tys.iter().rev() {
1212 fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
1213 typed = TypedExpr::new(
1214 fun_ty.clone(),
1215 TypedExprKind::Lam {
1216 param: name.clone(),
1217 body: Box::new(typed),
1218 },
1219 );
1220 }
1221
1222 Ok((preds, fun_ty, typed))
1223 }
1224 Expr::App(..) => {
1225 let (head, args) = collect_app_chain(expr);
1226 let (mut preds, mut func_ty, mut typed) =
1227 infer_expr(unifier, supply, env, adts, known, head)?;
1228 let mut overload_name = None;
1229 let mut overload_candidates = match &typed.kind {
1230 TypedExprKind::Var { name, overloads } if !overloads.is_empty() => {
1231 overload_name = Some(name.clone());
1232 Some(overloads.clone())
1233 }
1234 _ => None,
1235 };
1236 for arg in args {
1237 let expected_arg = match unifier.apply_type(&func_ty).as_ref() {
1238 TypeKind::Fun(arg, _) => Some(arg.clone()),
1239 _ => None,
1240 };
1241 let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
1242 TypeKind::Fun(arg, _) => Some(arg.clone()),
1243 _ => None,
1244 };
1245 let (p_arg, arg_ty, typed_arg) =
1246 infer_app_arg_typed(unifier, supply, env, adts, known, arg_hint, arg)?;
1247 let mut arg_ty = unifier.apply_type(&arg_ty);
1248 let mut typed_arg = typed_arg;
1249
1250 if let Some(expected_arg) = expected_arg {
1251 let expected_arg = unifier.apply_type(&expected_arg);
1252 if let (Some(expected_elem), Some(arg_elem)) = (
1253 unary_app_arg(&expected_arg, "Array"),
1254 unary_app_arg(&arg_ty, "List"),
1255 ) {
1256 unifier.unify(&expected_elem, &arg_elem)?;
1257 let elem_ty = unifier.apply_type(&expected_elem);
1258 let list_ty = Type::list(elem_ty.clone());
1259 let array_ty = Type::array(elem_ty);
1260 let coercion_ty = Type::fun(list_ty, array_ty.clone());
1261 let coercion_fn = TypedExpr::new(
1262 coercion_ty,
1263 TypedExprKind::Var {
1264 name: sym("prim_array_from_list"),
1265 overloads: vec![],
1266 },
1267 );
1268 typed_arg = TypedExpr::new(
1269 array_ty.clone(),
1270 TypedExprKind::App(Box::new(coercion_fn), Box::new(typed_arg)),
1271 );
1272 arg_ty = array_ty;
1273 }
1274 }
1275 if let Some(candidates) = overload_candidates.take() {
1276 let candidates = candidates
1277 .into_iter()
1278 .map(|t| unifier.apply_type(&t))
1279 .collect::<Vec<_>>();
1280 let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
1281 if narrowed.is_empty()
1282 && let Some(name) = &overload_name
1283 {
1284 return Err(TypeError::AmbiguousOverload(name.clone()));
1285 }
1286 overload_candidates = Some(narrowed);
1287 }
1288 let res_ty = match overload_candidates.as_ref() {
1289 Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
1290 _ => Type::var(supply.fresh(Some("r".into()))),
1291 };
1292 unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
1293 let result_ty = match overload_candidates.as_ref() {
1294 Some(candidates) if candidates.len() == 1 => {
1295 unifier.apply_type(&candidates[0])
1296 }
1297 _ => unifier.apply_type(&res_ty),
1298 };
1299 preds.extend(p_arg);
1300 typed = TypedExpr::new(
1301 result_ty.clone(),
1302 TypedExprKind::App(Box::new(typed), Box::new(typed_arg)),
1303 );
1304 func_ty = result_ty;
1305 }
1306 Ok((preds, func_ty, typed))
1307 }
1308 Expr::Project(_, base, field) => {
1309 let (p1, t1, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
1310 let base_ty = unifier.apply_type(&t1);
1311 let known_variant =
1312 known_variant_from_expr_with_known(base, &base_ty, adts, known);
1313 let field_ty =
1314 resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
1315 let typed = TypedExpr::new(
1316 field_ty.clone(),
1317 TypedExprKind::Project {
1318 expr: Box::new(typed_base),
1319 field: field.clone(),
1320 },
1321 );
1322 Ok((p1, field_ty, typed))
1323 }
1324 Expr::RecordUpdate(_, base, updates) => {
1325 let (p_base, t_base, typed_base) =
1326 infer_expr(unifier, supply, env, adts, known, base)?;
1327 let base_ty = unifier.apply_type(&t_base);
1328 let known_variant =
1329 known_variant_from_expr_with_known(base, &base_ty, adts, known);
1330 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
1331 let (result_ty, fields) = resolve_record_update(
1332 unifier,
1333 supply,
1334 adts,
1335 &base_ty,
1336 known_variant,
1337 &update_fields,
1338 )?;
1339 let expected: BTreeMap<_, _> = fields.into_iter().collect();
1340
1341 let mut preds = p_base;
1342 let mut typed_updates = BTreeMap::new();
1343 for (k, v) in updates {
1344 let expected_ty =
1345 expected.get(k).ok_or_else(|| TypeError::UnknownField {
1346 field: k.clone(),
1347 typ: result_ty.to_string(),
1348 })?;
1349 let (p1, t1, typed_v) =
1350 infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
1351 unifier.unify(&t1, expected_ty)?;
1352 preds.extend(p1);
1353 typed_updates.insert(k.clone(), typed_v);
1354 }
1355 let typed = TypedExpr::new(
1356 result_ty.clone(),
1357 TypedExprKind::RecordUpdate {
1358 base: Box::new(typed_base),
1359 updates: typed_updates,
1360 },
1361 );
1362 Ok((preds, result_ty, typed))
1363 }
1364 Expr::Let(..) => {
1365 let mut bindings = Vec::new();
1366 let mut cur = expr;
1367 while let Expr::Let(_, v, ann, d, b) = cur {
1368 bindings.push((v.clone(), ann.clone(), d.clone()));
1369 cur = b.as_ref();
1370 }
1371
1372 let mut env_cur = env.clone();
1373 let mut known_cur = known.clone();
1374 let mut typed_defs = Vec::new();
1375 for (v, ann, d) in bindings {
1376 let (p1, t1, typed_def) = if let Some(ref ann_expr) = ann {
1377 let mut ann_vars = BTreeMap::new();
1378 let ann_ty = type_from_annotation_expr_vars(
1379 adts,
1380 ann_expr,
1381 &mut ann_vars,
1382 supply,
1383 )?;
1384 match d.as_ref() {
1385 Expr::RecordUpdate(_, base, updates) => {
1386 infer_record_update_typed_with_hint(
1387 unifier,
1388 supply,
1389 &env_cur,
1390 adts,
1391 &known_cur,
1392 base.as_ref(),
1393 updates,
1394 &ann_ty,
1395 )?
1396 }
1397 _ => {
1398 let (p1, t1, typed_def) = infer_expr(
1399 unifier, supply, &env_cur, adts, &known_cur, &d,
1400 )?;
1401 unifier.unify(&t1, &ann_ty)?;
1402 (p1, t1, typed_def)
1403 }
1404 }
1405 } else {
1406 infer_expr(unifier, supply, &env_cur, adts, &known_cur, &d)?
1407 };
1408 let def_ty = unifier.apply_type(&t1);
1409 let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
1410 monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
1411 } else {
1412 let scheme =
1413 generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
1414 reject_ambiguous_scheme(&scheme)?;
1415 scheme
1416 };
1417 env_cur.extend(v.name.clone(), scheme);
1418 if let Some(known_variant) =
1419 known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
1420 {
1421 known_cur.insert(
1422 v.name.clone(),
1423 KnownVariant {
1424 adt: known_variant.adt,
1425 variant: known_variant.variant,
1426 },
1427 );
1428 } else {
1429 known_cur.remove(&v.name);
1430 }
1431 typed_defs.push((v.name.clone(), typed_def));
1432 }
1433
1434 let (p_body, t_body, typed_body) =
1435 infer_expr(unifier, supply, &env_cur, adts, &known_cur, cur)?;
1436
1437 let mut typed = typed_body;
1438 for (name, def) in typed_defs.into_iter().rev() {
1439 typed = TypedExpr::new(
1440 t_body.clone(),
1441 TypedExprKind::Let {
1442 name,
1443 def: Box::new(def),
1444 body: Box::new(typed),
1445 },
1446 );
1447 }
1448 Ok((p_body, t_body, typed))
1449 }
1450 Expr::LetRec(_, bindings, body) => {
1451 let mut env_seed = env.clone();
1452 let mut known_seed = known.clone();
1453 let mut binding_tys = BTreeMap::new();
1454 for (var, _ann, _def) in bindings {
1455 let tv = Type::var(supply.fresh(Some(var.name.clone())));
1456 env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
1457 known_seed.remove(&var.name);
1458 binding_tys.insert(var.name.clone(), tv);
1459 }
1460
1461 let mut inferred_defs = Vec::with_capacity(bindings.len());
1462 for (var, ann, def) in bindings {
1463 let (preds, def_ty, typed_def) =
1464 infer_expr(unifier, supply, &env_seed, adts, &known_seed, def)?;
1465 if let Some(ann) = ann {
1466 let mut ann_vars = BTreeMap::new();
1467 let ann_ty =
1468 type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
1469 unifier.unify(&def_ty, &ann_ty)?;
1470 }
1471 let binding_ty = binding_tys
1472 .get(&var.name)
1473 .cloned()
1474 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1475 unifier.unify(&binding_ty, &def_ty)?;
1476 let resolved_ty = unifier.apply_type(&binding_ty);
1477
1478 if let Some(known_variant) =
1479 known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
1480 {
1481 known_seed.insert(
1482 var.name.clone(),
1483 KnownVariant {
1484 adt: known_variant.adt,
1485 variant: known_variant.variant,
1486 },
1487 );
1488 } else {
1489 known_seed.remove(&var.name);
1490 }
1491 inferred_defs.push((var.name.clone(), preds, resolved_ty, typed_def));
1492 }
1493
1494 let mut env_body = env.clone();
1495 let mut typed_bindings = Vec::with_capacity(inferred_defs.len());
1496 for (name, preds, def_ty, typed_def) in inferred_defs {
1497 let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
1498 reject_ambiguous_scheme(&scheme)?;
1499 env_body.extend(name.clone(), scheme);
1500 typed_bindings.push((name, typed_def));
1501 }
1502
1503 let (p_body, t_body, typed_body) =
1504 infer_expr(unifier, supply, &env_body, adts, &known_seed, body)?;
1505 let typed = TypedExpr::new(
1506 t_body.clone(),
1507 TypedExprKind::LetRec {
1508 bindings: typed_bindings,
1509 body: Box::new(typed_body),
1510 },
1511 );
1512 Ok((p_body, t_body, typed))
1513 }
1514 Expr::Ite(_, cond, then_expr, else_expr) => {
1515 let (p1, t1, typed_cond) = infer_expr(unifier, supply, env, adts, known, cond)?;
1516 unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
1517 let (p2, t2, typed_then) =
1518 infer_expr(unifier, supply, env, adts, known, then_expr)?;
1519 let (p3, t3, typed_else) =
1520 infer_expr(unifier, supply, env, adts, known, else_expr)?;
1521 unifier.unify(&t2, &t3)?;
1522 let out_ty = unifier.apply_type(&t2);
1523 let mut preds = p1;
1524 preds.extend(p2);
1525 preds.extend(p3);
1526 let typed = TypedExpr::new(
1527 out_ty.clone(),
1528 TypedExprKind::Ite {
1529 cond: Box::new(typed_cond),
1530 then_expr: Box::new(typed_then),
1531 else_expr: Box::new(typed_else),
1532 },
1533 );
1534 Ok((preds, out_ty, typed))
1535 }
1536 Expr::Tuple(_, elems) => {
1537 let mut preds = Vec::new();
1538 let mut types = Vec::new();
1539 let mut typed_elems = Vec::new();
1540 for elem in elems {
1541 let (p1, t1, typed_elem) =
1542 infer_expr(unifier, supply, env, adts, known, elem)?;
1543 preds.extend(p1);
1544 types.push(unifier.apply_type(&t1));
1545 typed_elems.push(typed_elem);
1546 }
1547 let tuple_ty = Type::tuple(types);
1548 let typed = TypedExpr::new(tuple_ty.clone(), TypedExprKind::Tuple(typed_elems));
1549 Ok((preds, tuple_ty, typed))
1550 }
1551 Expr::List(_, elems) => {
1552 let elem_tv = Type::var(supply.fresh(Some("a".into())));
1553 let mut preds = Vec::new();
1554 let mut typed_elems = Vec::new();
1555 for elem in elems {
1556 let (p1, t1, typed_elem) =
1557 infer_expr(unifier, supply, env, adts, known, elem)?;
1558 unifier.unify(&t1, &elem_tv)?;
1559 preds.extend(p1);
1560 typed_elems.push(typed_elem);
1561 }
1562 let list_ty = Type::app(
1563 Type::builtin(BuiltinTypeId::List),
1564 unifier.apply_type(&elem_tv),
1565 );
1566 let typed = TypedExpr::new(list_ty.clone(), TypedExprKind::List(typed_elems));
1567 Ok((preds, list_ty, typed))
1568 }
1569 Expr::Dict(_, kvs) => {
1570 let elem_tv = Type::var(supply.fresh(Some("v".into())));
1571 let mut preds = Vec::new();
1572 let mut typed_kvs = BTreeMap::new();
1573 for (k, v) in kvs {
1574 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v)?;
1575 unifier.unify(&t1, &elem_tv)?;
1576 preds.extend(p1);
1577 typed_kvs.insert(k.clone(), typed_v);
1578 }
1579 let dict_ty = Type::app(
1580 Type::builtin(BuiltinTypeId::Dict),
1581 unifier.apply_type(&elem_tv),
1582 );
1583 let typed = TypedExpr::new(dict_ty.clone(), TypedExprKind::Dict(typed_kvs));
1584 Ok((preds, dict_ty, typed))
1585 }
1586 Expr::Match(_, scrutinee, arms) => {
1587 let (p1, t1, typed_scrutinee) =
1588 infer_expr(unifier, supply, env, adts, known, scrutinee)?;
1589 let mut preds = p1;
1590 let mut typed_arms = Vec::new();
1591 let res_ty = Type::var(supply.fresh(Some("match".into())));
1592 let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
1593
1594 for (pat, expr) in arms {
1595 let scrutinee_ty = unifier.apply_type(&t1);
1596 let (p_pat, binds) =
1597 infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
1598 preds.extend(p_pat);
1599
1600 let mut env_arm = env.clone();
1601 for (name, ty) in binds {
1602 env_arm
1603 .extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
1604 }
1605 let mut known_arm = known.clone();
1606 if let Expr::Var(var) = scrutinee.as_ref() {
1607 match pat {
1608 Pattern::Named(_, name, _) => {
1609 let name_sym = name.to_dotted_symbol();
1610 if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
1611 known_arm.insert(
1612 var.name.clone(),
1613 KnownVariant {
1614 adt: adt.name.clone(),
1615 variant: name_sym,
1616 },
1617 );
1618 } else {
1619 known_arm.remove(&var.name);
1620 }
1621 }
1622 _ => {
1623 known_arm.remove(&var.name);
1624 }
1625 }
1626 }
1627 let (p_expr, t_expr, typed_expr) =
1628 infer_expr(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1629 unifier.unify(&res_ty, &t_expr)?;
1630 preds.extend(p_expr);
1631 typed_arms.push((pat.clone(), typed_expr));
1632 }
1633
1634 let scrutinee_ty = unifier.apply_type(&t1);
1635 check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1636 let out_ty = unifier.apply_type(&res_ty);
1637 let typed = TypedExpr::new(
1638 out_ty.clone(),
1639 TypedExprKind::Match {
1640 scrutinee: Box::new(typed_scrutinee),
1641 arms: typed_arms,
1642 },
1643 );
1644 Ok((preds, out_ty, typed))
1645 }
1646 Expr::Ann(_, expr, ann) => {
1647 let ann_ty = type_from_annotation_expr(adts, ann)?;
1648 match expr.as_ref() {
1649 Expr::RecordUpdate(_, base, updates) => {
1650 infer_record_update_typed_with_hint(
1651 unifier,
1652 supply,
1653 env,
1654 adts,
1655 known,
1656 base.as_ref(),
1657 updates,
1658 &ann_ty,
1659 )
1660 }
1661 _ => {
1662 let (preds, expr_ty, typed_expr) =
1663 infer_expr(unifier, supply, env, adts, known, expr)?;
1664 unifier.unify(&expr_ty, &ann_ty)?;
1665 let out_ty = unifier.apply_type(&ann_ty);
1666 Ok((preds, out_ty, typed_expr))
1667 }
1668 }
1669 }
1670 }
1671 })()
1672 });
1673 res.map_err(|err| err.with_span(&span))
1674}
1675
1676fn ctor_lookup<'a>(
1677 adts: &'a BTreeMap<Symbol, AdtDecl>,
1678 name: &Symbol,
1679) -> Option<(&'a AdtDecl, &'a AdtVariant)> {
1680 let mut found = None;
1681 for adt in adts.values() {
1682 if let Some(variant) = adt.variants.iter().find(|v| &v.name == name) {
1683 if found.is_some() {
1684 return None;
1685 }
1686 found = Some((adt, variant));
1687 }
1688 }
1689 found
1690}
1691
1692fn record_fields(variant: &AdtVariant) -> Option<&[(Symbol, Type)]> {
1693 if variant.args.len() != 1 {
1694 return None;
1695 }
1696 match variant.args[0].as_ref() {
1697 TypeKind::Record(fields) => Some(fields),
1698 _ => None,
1699 }
1700}
1701
1702fn instantiate_variant_fields(
1703 adt: &AdtDecl,
1704 variant: &AdtVariant,
1705 supply: &mut TypeVarSupply,
1706) -> Option<(Type, Vec<(Symbol, Type)>)> {
1707 let fields = record_fields(variant)?;
1708 let mut subst = Subst::new_sync();
1709 for param in &adt.params {
1710 let fresh = Type::var(supply.fresh(param.var.name.clone()));
1711 subst = subst.insert(param.var.id, fresh);
1712 }
1713 let result_ty = adt.result_type().apply(&subst);
1714 let fields = fields
1715 .iter()
1716 .map(|(name, ty)| (name.clone(), ty.apply(&subst)))
1717 .collect();
1718 Some((result_ty, fields))
1719}
1720
1721fn known_variant_from_expr(
1722 expr: &Expr,
1723 expr_ty: &Type,
1724 adts: &BTreeMap<Symbol, AdtDecl>,
1725) -> Option<KnownVariant> {
1726 let mut expr = expr;
1727 while let Expr::Ann(_, inner, _) = expr {
1728 expr = inner.as_ref();
1729 }
1730 if matches!(expr_ty.as_ref(), TypeKind::Fun(..)) {
1731 return None;
1732 }
1733 let ctor = match expr {
1734 Expr::App(_, f, _) => match f.as_ref() {
1735 Expr::Var(var) => var.name.clone(),
1736 _ => return None,
1737 },
1738 _ => return None,
1739 };
1740 let (adt, variant) = ctor_lookup(adts, &ctor)?;
1741 record_fields(variant)?;
1742 Some(KnownVariant {
1743 adt: adt.name.clone(),
1744 variant: variant.name.clone(),
1745 })
1746}
1747
1748fn known_variant_from_expr_with_known(
1749 expr: &Expr,
1750 expr_ty: &Type,
1751 adts: &BTreeMap<Symbol, AdtDecl>,
1752 known: &KnownVariants,
1753) -> Option<KnownVariant> {
1754 let mut expr = expr;
1755 while let Expr::Ann(_, inner, _) = expr {
1756 expr = inner.as_ref();
1757 }
1758 match expr {
1759 Expr::Var(var) => known.get(&var.name).cloned(),
1760 Expr::RecordUpdate(_, base, _) => {
1761 known_variant_from_expr_with_known(base.as_ref(), expr_ty, adts, known)
1762 }
1763 _ => known_variant_from_expr(expr, expr_ty, adts),
1764 }
1765}
1766
1767fn select_record_variant<'a, F>(
1768 adts: &'a BTreeMap<Symbol, AdtDecl>,
1769 base_ty: &Type,
1770 known_variant: Option<KnownVariant>,
1771 field_for_errors: &Symbol,
1772 matches_fields: F,
1773) -> Result<(&'a AdtDecl, &'a AdtVariant), TypeError>
1774where
1775 F: Fn(&[(Symbol, Type)]) -> bool,
1776{
1777 if let Some(info) = known_variant {
1778 let adt = adts
1779 .get(&info.adt)
1780 .ok_or_else(|| TypeError::UnknownTypeName(info.adt.clone()))?;
1781 let variant = adt
1782 .variants
1783 .iter()
1784 .find(|v| v.name == info.variant)
1785 .ok_or_else(|| TypeError::UnknownField {
1786 field: field_for_errors.clone(),
1787 typ: base_ty.to_string(),
1788 })?;
1789 return Ok((adt, variant));
1790 }
1791
1792 if let Some(adt_name) = type_head_name(base_ty) {
1793 let adt = adts.get(adt_name).ok_or_else(|| TypeError::UnknownField {
1794 field: field_for_errors.clone(),
1795 typ: base_ty.to_string(),
1796 })?;
1797 if adt.variants.len() == 1 {
1798 return Ok((adt, &adt.variants[0]));
1799 }
1800 return Err(TypeError::FieldNotKnown {
1801 field: field_for_errors.clone(),
1802 typ: base_ty.to_string(),
1803 });
1804 }
1805
1806 if matches!(base_ty.as_ref(), TypeKind::Var(_)) {
1807 let mut candidates = Vec::new();
1808 for adt in adts.values() {
1809 if adt.variants.len() != 1 {
1810 continue;
1811 }
1812 let variant = &adt.variants[0];
1813 let Some(fields) = record_fields(variant) else {
1814 continue;
1815 };
1816 if matches_fields(fields) {
1817 candidates.push((adt, variant));
1818 }
1819 }
1820 if candidates.len() == 1 {
1821 return Ok(candidates.remove(0));
1822 }
1823 if candidates.is_empty() {
1824 return Err(TypeError::UnknownField {
1825 field: field_for_errors.clone(),
1826 typ: base_ty.to_string(),
1827 });
1828 }
1829 return Err(TypeError::FieldNotKnown {
1830 field: field_for_errors.clone(),
1831 typ: base_ty.to_string(),
1832 });
1833 }
1834
1835 Err(TypeError::UnknownField {
1836 field: field_for_errors.clone(),
1837 typ: base_ty.to_string(),
1838 })
1839}
1840
1841fn resolve_record_update(
1842 unifier: &mut Unifier<'_>,
1843 supply: &mut TypeVarSupply,
1844 adts: &BTreeMap<Symbol, AdtDecl>,
1845 base_ty: &Type,
1846 known_variant: Option<KnownVariant>,
1847 update_fields: &[Symbol],
1848) -> Result<(Type, Vec<(Symbol, Type)>), TypeError> {
1849 if let TypeKind::Record(fields) = base_ty.as_ref() {
1850 return Ok((base_ty.clone(), fields.clone()));
1851 }
1852
1853 let field_for_errors = update_fields.first().cloned().unwrap_or_else(|| sym("_"));
1854
1855 let (adt, variant) =
1856 select_record_variant(adts, base_ty, known_variant, &field_for_errors, |fields| {
1857 update_fields
1858 .iter()
1859 .all(|field| fields.iter().any(|(name, _)| name == field))
1860 })?;
1861
1862 let (result_ty, fields) =
1863 instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1864 TypeError::UnknownField {
1865 field: field_for_errors.clone(),
1866 typ: base_ty.to_string(),
1867 }
1868 })?;
1869
1870 for field in update_fields {
1871 if fields.iter().all(|(name, _)| name != field) {
1872 return Err(TypeError::UnknownField {
1873 field: field.clone(),
1874 typ: base_ty.to_string(),
1875 });
1876 }
1877 }
1878
1879 unifier.unify(base_ty, &result_ty)?;
1880 let result_ty = unifier.apply_type(&result_ty);
1881 let fields = fields
1882 .into_iter()
1883 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
1884 .collect();
1885 Ok((result_ty, fields))
1886}
1887
1888fn resolve_projection(
1889 unifier: &mut Unifier<'_>,
1890 supply: &mut TypeVarSupply,
1891 adts: &BTreeMap<Symbol, AdtDecl>,
1892 base_ty: &Type,
1893 known_variant: Option<KnownVariant>,
1894 field: &Symbol,
1895) -> Result<Type, TypeError> {
1896 if let Ok(index) = field.as_ref().parse::<usize>() {
1897 let elem_ty = match base_ty.as_ref() {
1898 TypeKind::Tuple(elems) => {
1899 elems
1900 .get(index)
1901 .cloned()
1902 .ok_or_else(|| TypeError::UnknownField {
1903 field: field.clone(),
1904 typ: base_ty.to_string(),
1905 })?
1906 }
1907 TypeKind::Var(_) => {
1908 let mut elems = Vec::with_capacity(index + 1);
1909 for _ in 0..=index {
1910 elems.push(Type::var(supply.fresh(Some(sym("t")))));
1911 }
1912 let tuple_ty = Type::tuple(elems.clone());
1913 unifier.unify(base_ty, &tuple_ty)?;
1914 elems[index].clone()
1915 }
1916 _ => {
1917 return Err(TypeError::UnknownField {
1918 field: field.clone(),
1919 typ: base_ty.to_string(),
1920 });
1921 }
1922 };
1923 return Ok(unifier.apply_type(&elem_ty));
1924 }
1925
1926 let (adt, variant) = select_record_variant(adts, base_ty, known_variant, field, |fields| {
1927 fields.iter().any(|(name, _)| name == field)
1928 })?;
1929
1930 let (result_ty, fields) =
1931 instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1932 TypeError::UnknownField {
1933 field: field.clone(),
1934 typ: base_ty.to_string(),
1935 }
1936 })?;
1937 let field_ty = fields
1938 .iter()
1939 .find(|(name, _)| name == field)
1940 .map(|(_, ty)| ty.clone())
1941 .ok_or_else(|| TypeError::UnknownField {
1942 field: field.clone(),
1943 typ: base_ty.to_string(),
1944 })?;
1945 unifier.unify(base_ty, &result_ty)?;
1946 Ok(unifier.apply_type(&field_ty))
1947}
1948
1949fn decompose_fun(typ: &Type, arity: usize) -> Option<(Vec<Type>, Type)> {
1950 let mut args = Vec::with_capacity(arity);
1951 let mut cur = typ.clone();
1952 for _ in 0..arity {
1953 match cur.as_ref() {
1954 TypeKind::Fun(a, b) => {
1955 args.push(a.clone());
1956 cur = b.clone();
1957 }
1958 _ => return None,
1959 }
1960 }
1961 Some((args, cur))
1962}
1963
1964type InferPatternResult = (Vec<Predicate>, Vec<(Symbol, Type)>);
1965
1966fn infer_pattern(
1967 unifier: &mut Unifier<'_>,
1968 supply: &mut TypeVarSupply,
1969 env: &TypeEnv,
1970 pat: &Pattern,
1971 scrutinee_ty: &Type,
1972) -> Result<InferPatternResult, TypeError> {
1973 let span = *pat.span();
1974 let res = (|| {
1975 unifier.charge_infer_node()?;
1976 match pat {
1977 Pattern::Wildcard(..) => Ok((vec![], vec![])),
1978 Pattern::Var(var) => Ok((
1979 vec![],
1980 vec![(var.name.clone(), unifier.apply_type(scrutinee_ty))],
1981 )),
1982 Pattern::Named(_, name, ps) => {
1983 let ctor_name = name.to_dotted_symbol();
1984 let schemes = env
1985 .lookup(&ctor_name)
1986 .ok_or_else(|| TypeError::UnknownVar(ctor_name.clone()))?;
1987 if schemes.len() != 1 {
1988 return Err(TypeError::AmbiguousOverload(ctor_name));
1989 }
1990 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
1991 let (preds, ctor_ty) = instantiate(&scheme, supply);
1992 let (arg_tys, res_ty) = decompose_fun(&ctor_ty, ps.len())
1993 .ok_or(TypeError::UnsupportedExpr("pattern constructor"))?;
1994 unifier.unify(&res_ty, scrutinee_ty)?;
1995 let mut all_preds = preds;
1996 let mut bindings = Vec::new();
1997 for (p, arg_ty) in ps.iter().zip(arg_tys.iter()) {
1998 let arg_ty = unifier.apply_type(arg_ty);
1999 let (p1, binds1) = infer_pattern(unifier, supply, env, p, &arg_ty)?;
2000 all_preds.extend(p1);
2001 bindings.extend(binds1);
2002 }
2003 let bindings = bindings
2004 .into_iter()
2005 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2006 .collect();
2007 Ok((all_preds, bindings))
2008 }
2009 Pattern::List(_, ps) => {
2010 let elem_tv = Type::var(supply.fresh(Some("a".into())));
2011 let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
2012 unifier.unify(scrutinee_ty, &list_ty)?;
2013 let mut preds = Vec::new();
2014 let mut bindings = Vec::new();
2015 for p in ps {
2016 let elem_ty = unifier.apply_type(&elem_tv);
2017 let (p1, binds1) = infer_pattern(unifier, supply, env, p, &elem_ty)?;
2018 preds.extend(p1);
2019 bindings.extend(binds1);
2020 }
2021 let bindings = bindings
2022 .into_iter()
2023 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2024 .collect();
2025 Ok((preds, bindings))
2026 }
2027 Pattern::Cons(_, head, tail) => {
2028 let elem_tv = Type::var(supply.fresh(Some("a".into())));
2029 let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
2030 unifier.unify(scrutinee_ty, &list_ty)?;
2031 let mut preds = Vec::new();
2032 let mut bindings = Vec::new();
2033
2034 let head_ty = unifier.apply_type(&elem_tv);
2035 let (p1, binds1) = infer_pattern(unifier, supply, env, head, &head_ty)?;
2036 preds.extend(p1);
2037 bindings.extend(binds1);
2038
2039 let tail_ty = Type::app(
2040 Type::builtin(BuiltinTypeId::List),
2041 unifier.apply_type(&elem_tv),
2042 );
2043 let (p2, binds2) = infer_pattern(unifier, supply, env, tail, &tail_ty)?;
2044 preds.extend(p2);
2045 bindings.extend(binds2);
2046
2047 let bindings = bindings
2048 .into_iter()
2049 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2050 .collect();
2051 Ok((preds, bindings))
2052 }
2053 Pattern::Tuple(_, elems) => {
2054 let mut elem_tys: Vec<Type> = (0..elems.len())
2055 .map(|i| Type::var(supply.fresh(Some(format!("t{i}").into()))))
2056 .collect();
2057 let expected = Type::tuple(elem_tys.clone());
2058 unifier.unify(scrutinee_ty, &expected)?;
2059 elem_tys = elem_tys
2060 .into_iter()
2061 .map(|t| unifier.apply_type(&t))
2062 .collect();
2063
2064 let mut preds = Vec::new();
2065 let mut bindings = Vec::new();
2066 for (p, ty) in elems.iter().zip(elem_tys.iter()) {
2067 let (p_preds, p_binds) = infer_pattern(unifier, supply, env, p, ty)?;
2068 preds.extend(p_preds);
2069 bindings.extend(p_binds);
2070 }
2071 let bindings = bindings
2072 .into_iter()
2073 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2074 .collect();
2075 Ok((preds, bindings))
2076 }
2077 Pattern::Dict(_, fields) => {
2078 if let TypeKind::Record(ty_fields) = scrutinee_ty.as_ref() {
2079 let mut preds = Vec::new();
2080 let mut bindings = Vec::new();
2081 for (key, pat) in fields {
2082 let ty = ty_fields
2083 .iter()
2084 .find(|(name, _)| name == key)
2085 .map(|(_, ty)| unifier.apply_type(ty))
2086 .ok_or_else(|| TypeError::UnknownField {
2087 field: key.clone(),
2088 typ: scrutinee_ty.to_string(),
2089 })?;
2090 let (p_preds, p_binds) = infer_pattern(unifier, supply, env, pat, &ty)?;
2091 preds.extend(p_preds);
2092 bindings.extend(p_binds);
2093 }
2094 let bindings = bindings
2095 .into_iter()
2096 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2097 .collect();
2098 Ok((preds, bindings))
2099 } else {
2100 let elem_tv = Type::var(supply.fresh(Some("v".into())));
2101 let dict_ty = Type::app(Type::builtin(BuiltinTypeId::Dict), elem_tv.clone());
2102 unifier.unify(scrutinee_ty, &dict_ty)?;
2103 let elem_ty = unifier.apply_type(&elem_tv);
2104
2105 let mut preds = Vec::new();
2106 let mut bindings = Vec::new();
2107 for (_key, pat) in fields {
2108 let (p_preds, p_binds) =
2109 infer_pattern(unifier, supply, env, pat, &elem_ty)?;
2110 preds.extend(p_preds);
2111 bindings.extend(p_binds);
2112 }
2113 let bindings = bindings
2114 .into_iter()
2115 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2116 .collect();
2117 Ok((preds, bindings))
2118 }
2119 }
2120 }
2121 })();
2122 res.map_err(|err| err.with_span(&span))
2123}
2124
2125fn type_head_name(typ: &Type) -> Option<&Symbol> {
2126 let mut cur = typ;
2127 while let TypeKind::App(head, _) = cur.as_ref() {
2128 cur = head;
2129 }
2130 match cur.as_ref() {
2131 TypeKind::Con(tc) => Some(&tc.name),
2132 _ => None,
2133 }
2134}
2135
2136fn adt_name_from_patterns(
2137 adts: &BTreeMap<Symbol, AdtDecl>,
2138 patterns: &[Pattern],
2139) -> Option<Symbol> {
2140 let mut candidate: Option<Symbol> = None;
2141 for pat in patterns {
2142 let next = match pat {
2143 Pattern::Named(_, name, _) => {
2144 let name_sym = name.to_dotted_symbol();
2145 ctor_lookup(adts, &name_sym).map(|(adt, _)| adt.name.clone())
2146 }
2147 Pattern::List(..) | Pattern::Cons(..) => Some(sym("List")),
2148 _ => None,
2149 };
2150 if let Some(next) = next {
2151 match &candidate {
2152 None => candidate = Some(next),
2153 Some(prev) if *prev == next => {}
2154 Some(_) => return None,
2155 }
2156 }
2157 }
2158 candidate
2159}
2160
2161fn check_match_exhaustive(
2162 adts: &BTreeMap<Symbol, AdtDecl>,
2163 scrutinee_ty: &Type,
2164 patterns: &[Pattern],
2165) -> Result<(), TypeError> {
2166 if patterns
2167 .iter()
2168 .any(|p| matches!(p, Pattern::Wildcard(..) | Pattern::Var(_)))
2169 {
2170 return Ok(());
2171 }
2172 let adt_name = match type_head_name(scrutinee_ty).cloned() {
2173 Some(name) => name,
2174 None => match adt_name_from_patterns(adts, patterns) {
2175 Some(name) => name,
2176 None => return Ok(()),
2177 },
2178 };
2179 let adt = match adts.get(&adt_name) {
2180 Some(adt) => adt,
2181 None => return Ok(()),
2182 };
2183 let ctor_names: BTreeSet<Symbol> = adt.variants.iter().map(|v| v.name.clone()).collect();
2184 if ctor_names.is_empty() {
2185 return Ok(());
2186 }
2187 let mut covered = BTreeSet::new();
2188 for pat in patterns {
2189 match pat {
2190 Pattern::Named(_, name, _) => {
2191 let name_sym = name.to_dotted_symbol();
2192 if ctor_names.contains(&name_sym) {
2193 covered.insert(name_sym);
2194 }
2195 }
2196 Pattern::List(_, elems) if adt_name.as_ref() == "List" && elems.is_empty() => {
2197 covered.insert(sym("Empty"));
2198 }
2199 Pattern::Cons(..) if adt_name.as_ref() == "List" => {
2200 covered.insert(sym("Cons"));
2201 }
2202 _ => {}
2203 }
2204 }
2205 let mut missing: Vec<Symbol> = ctor_names.difference(&covered).cloned().collect();
2206 if missing.is_empty() {
2207 return Ok(());
2208 }
2209 missing.sort();
2210 Err(TypeError::NonExhaustiveMatch {
2211 typ: scrutinee_ty.to_string(),
2212 missing,
2213 })
2214}
2215
2216#[cfg(test)]
2217mod tests {
2218 use super::*;
2219 use crate::{
2220 types::collect_adts_in_types,
2221 typesystem::{TypeSystemLimits, entails, generalize},
2222 unification::bind,
2223 };
2224 use rex_lexer::{Token, span::Span};
2225 use rex_parser::Parser;
2226 use rex_util::{GasCosts, GasMeter};
2227
2228 fn tvar(id: TypeVarId, name: &str) -> Type {
2229 Type::var(TypeVar::new(id, Some(sym(name))))
2230 }
2231
2232 fn dict_of(elem: Type) -> Type {
2233 Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
2234 }
2235
2236 #[test]
2237 fn unify_simple() {
2238 let t1 = Type::fun(tvar(0, "a"), Type::builtin(BuiltinTypeId::U32));
2239 let t2 = Type::fun(Type::builtin(BuiltinTypeId::U16), tvar(1, "b"));
2240 let subst = unify(&t1, &t2).unwrap();
2241 assert_eq!(subst.get(&0), Some(&Type::builtin(BuiltinTypeId::U16)));
2242 assert_eq!(subst.get(&1), Some(&Type::builtin(BuiltinTypeId::U32)));
2243 }
2244
2245 #[test]
2246 fn occurs_check_blocks_infinite_type() {
2247 let tv = TypeVar::new(0, Some(sym("a")));
2248 let t = Type::fun(Type::var(tv.clone()), Type::builtin(BuiltinTypeId::U8));
2249 let err = bind(&tv, &t).unwrap_err();
2250 assert!(matches!(err, TypeError::Occurs(_, _)));
2251 }
2252
2253 #[test]
2254 fn instantiate_and_generalize_round_trip() {
2255 let mut supply = TypeVarSupply::new();
2256 let a = Type::var(supply.fresh(Some(sym("a"))));
2257 let scheme = generalize(&TypeEnv::new(), vec![], Type::fun(a.clone(), a.clone()));
2258 let (preds, inst) = instantiate(&scheme, &mut supply);
2259 assert!(preds.is_empty());
2260 if let TypeKind::Fun(l, r) = inst.as_ref() {
2261 match (l.as_ref(), r.as_ref()) {
2262 (TypeKind::Var(_), TypeKind::Var(_)) => {}
2263 _ => panic!("expected polymorphic identity"),
2264 }
2265 } else {
2266 panic!("expected function type");
2267 }
2268 }
2269
2270 #[test]
2271 fn entail_superclasses() {
2272 let ts = TypeSystem::new_with_prelude().unwrap();
2273 let pred = Predicate::new("Semiring", Type::builtin(BuiltinTypeId::I32));
2274 let given = [Predicate::new(
2275 "AdditiveGroup",
2276 Type::builtin(BuiltinTypeId::I32),
2277 )];
2278 assert!(entails(&ts.classes, &given, &pred).unwrap());
2279 }
2280
2281 #[test]
2282 fn entail_instances() {
2283 let ts = TypeSystem::new_with_prelude().unwrap();
2284 let pred = Predicate::new("Field", Type::builtin(BuiltinTypeId::F32));
2285 assert!(entails(&ts.classes, &[], &pred).unwrap());
2286
2287 let pred_fail = Predicate::new("Field", Type::builtin(BuiltinTypeId::U32));
2288 assert!(!entails(&ts.classes, &[], &pred_fail).unwrap());
2289 }
2290
2291 #[test]
2292 fn prelude_injects_functions() {
2293 let ts = TypeSystem::new_with_prelude().unwrap();
2294 let minus = ts.env.lookup(&sym("-")).expect("minus in env");
2295 let div = ts.env.lookup(&sym("/")).expect("div in env");
2296 assert_eq!(minus.len(), 1);
2297 assert_eq!(div.len(), 1);
2298 let minus = &minus[0];
2299 let div = &div[0];
2300 assert_eq!(minus.preds.len(), 1);
2301 assert_eq!(minus.vars.len(), 1);
2302 assert_eq!(div.preds.len(), 1);
2303 assert_eq!(div.vars.len(), 1);
2304 }
2305
2306 #[test]
2307 fn adt_constructors_are_present() {
2308 let ts = TypeSystem::new_with_prelude().unwrap();
2309 assert!(ts.env.lookup(&sym("Empty")).is_some());
2310 assert!(ts.env.lookup(&sym("Cons")).is_some());
2311 assert!(ts.env.lookup(&sym("Ok")).is_some());
2312 assert!(ts.env.lookup(&sym("Err")).is_some());
2313 assert!(ts.env.lookup(&sym("Some")).is_some());
2314 assert!(ts.env.lookup(&sym("None")).is_some());
2315 }
2316
2317 fn parse_expr(code: &str) -> std::sync::Arc<rex_ast::expr::Expr> {
2318 let mut parser = Parser::new(Token::tokenize(code).unwrap());
2319 parser.parse_program(&mut GasMeter::default()).unwrap().expr
2320 }
2321
2322 fn parse_program(code: &str) -> rex_ast::expr::Program {
2323 let mut parser = Parser::new(Token::tokenize(code).unwrap());
2324 parser.parse_program(&mut GasMeter::default()).unwrap()
2325 }
2326
2327 #[test]
2328 fn infer_deep_list_does_not_overflow() {
2329 const N: usize = 40;
2330 let mut code = String::new();
2331 code.push_str("let xs = ");
2332 for _ in 0..N {
2333 code.push_str("Cons 0 (");
2334 }
2335 code.push_str("Empty");
2336 for _ in 0..N {
2337 code.push(')');
2338 }
2339 code.push_str(" in xs");
2340
2341 let parse_handle = std::thread::Builder::new()
2342 .name("infer_deep_list_parse".into())
2343 .stack_size(128 * 1024 * 1024)
2344 .spawn(move || {
2345 let tokens = Token::tokenize(&code).unwrap();
2346 let mut parser = Parser::new(tokens);
2347 parser.parse_program(&mut GasMeter::default())
2348 })
2349 .unwrap();
2350 let program = parse_handle.join().unwrap().unwrap();
2351 let expr = program.expr;
2352 let mut ts = TypeSystem::new_with_prelude().unwrap();
2353 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2354 assert_eq!(
2355 ty,
2356 Type::app(
2357 Type::builtin(BuiltinTypeId::List),
2358 Type::builtin(BuiltinTypeId::I32)
2359 )
2360 );
2361 }
2362
2363 #[test]
2364 fn collect_adts_in_types_finds_nested_unique_adts() {
2365 let foo = Type::user_con("Foo", 1);
2366 let bar = Type::user_con("Bar", 0);
2367 let ty = Type::fun(
2368 Type::app(
2369 Type::builtin(BuiltinTypeId::List),
2370 Type::app(foo.clone(), tvar(0, "a")),
2371 ),
2372 Type::tuple(vec![
2373 Type::app(foo.clone(), Type::builtin(BuiltinTypeId::I32)),
2374 bar.clone(),
2375 ]),
2376 );
2377
2378 let adts = collect_adts_in_types(vec![ty]).unwrap();
2379 assert_eq!(adts, vec![foo, bar]);
2380 }
2381
2382 #[test]
2383 fn collect_adts_in_types_rejects_conflicting_names() {
2384 let arity1 = Type::user_con("Thing", 1);
2385 let arity2 = Type::user_con("Thing", 2);
2386
2387 let err = collect_adts_in_types(vec![arity1.clone(), arity2.clone()]).unwrap_err();
2388 assert_eq!(err.conflicts.len(), 1);
2389 let conflict = &err.conflicts[0];
2390 assert_eq!(conflict.name, sym("Thing"));
2391 assert_eq!(conflict.definitions, vec![arity1, arity2]);
2392 }
2393
2394 #[test]
2395 fn infer_depth_limit_is_enforced() {
2396 const N: usize = 40;
2397 let mut code = String::new();
2398 code.push_str("let xs = ");
2399 for _ in 0..N {
2400 code.push_str("Cons 0 (");
2401 }
2402 code.push_str("Empty");
2403 for _ in 0..N {
2404 code.push(')');
2405 }
2406 code.push_str(" in xs");
2407
2408 let program = parse_program(&code);
2409 let mut ts = TypeSystem::new_with_prelude().unwrap();
2410 ts.set_limits(TypeSystemLimits {
2411 max_infer_depth: Some(8),
2412 });
2413
2414 let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
2415 assert!(
2416 err.to_string().contains("maximum inference depth exceeded"),
2417 "expected a max-depth inference error, got: {err:?}"
2418 );
2419 }
2420
2421 #[test]
2422 fn declare_fn_injects_scheme_for_use_sites() {
2423 let program = parse_program(
2424 r#"
2425 declare fn id x: a -> a
2426 id 1
2427 "#,
2428 );
2429 let mut ts = TypeSystem::new_with_prelude().unwrap();
2430 ts.register_decls(&program.decls).unwrap();
2431 let (preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2432 assert!(
2433 preds.is_empty()
2434 || preds.iter().all(|p| p.class.as_ref() == "Integral"
2435 && p.typ == Type::builtin(BuiltinTypeId::I32))
2436 );
2437 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2438 }
2439
2440 #[test]
2441 fn declare_fn_is_noop_when_matching_existing_scheme() {
2442 let mut ts = TypeSystem::new_with_prelude().unwrap();
2443 ts.add_value(
2444 "foo",
2445 Scheme::new(
2446 vec![],
2447 vec![],
2448 Type::fun(
2449 Type::builtin(BuiltinTypeId::I32),
2450 Type::builtin(BuiltinTypeId::I32),
2451 ),
2452 ),
2453 );
2454
2455 let program = parse_program(
2456 r#"
2457 declare fn foo x: i32 -> i32
2458 0
2459 "#,
2460 );
2461 let rex_ast::expr::Decl::DeclareFn(fd) = &program.decls[0] else {
2462 panic!("expected declare fn decl");
2463 };
2464 ts.inject_declare_fn_decl(fd).unwrap();
2465 }
2466
2467 #[test]
2468 fn unit_type_parses_and_infers() {
2469 let program = parse_program(
2470 r#"
2471 fn unit_id x: () -> () = x
2472 unit_id ()
2473 "#,
2474 );
2475 let mut ts = TypeSystem::new_with_prelude().unwrap();
2476 ts.register_decls(&program.decls).unwrap();
2477 let (preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2478 assert!(preds.is_empty());
2479 assert_eq!(ty, Type::tuple(vec![]));
2480 }
2481
2482 fn strip_span(mut err: TypeError) -> TypeError {
2483 while let TypeError::Spanned { error, .. } = err {
2484 err = *error;
2485 }
2486 err
2487 }
2488
2489 #[test]
2490 fn type_errors_include_span() {
2491 let expr = parse_expr("missing");
2492 let mut ts = TypeSystem::new_with_prelude().unwrap();
2493 let err = infer(&mut ts, expr.as_ref()).unwrap_err();
2494 match err {
2495 TypeError::Spanned { span, error } => {
2496 assert_ne!(span, Span::default());
2497 assert!(matches!(
2498 *error,
2499 TypeError::UnknownVar(name) if name.as_ref() == "missing"
2500 ));
2501 }
2502 other => panic!("expected spanned error, got {other:?}"),
2503 }
2504 }
2505
2506 #[test]
2507 fn infer_with_gas_rejects_out_of_budget() {
2508 let expr = parse_expr("1");
2509 let mut ts = TypeSystem::new_with_prelude().unwrap();
2510 let mut gas = GasMeter::new(
2511 Some(0),
2512 GasCosts {
2513 infer_node: 1,
2514 unify_step: 0,
2515 ..GasCosts::sensible_defaults()
2516 },
2517 );
2518 let err = infer_with_gas(&mut ts, expr.as_ref(), &mut gas).unwrap_err();
2519 assert!(matches!(strip_span(err), TypeError::OutOfGas(..)));
2520 }
2521
2522 #[test]
2523 fn reject_user_redefinition_of_primitive_type_name() {
2524 let program = parse_program("type i32 = I32Wrap i32");
2525 let mut ts = TypeSystem::new_with_prelude().unwrap();
2526 let rex_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2527 panic!("expected type decl");
2528 };
2529 let err = ts.register_type_decl(decl).unwrap_err();
2530 assert!(matches!(
2531 err,
2532 TypeError::ReservedTypeName(name) if name.as_ref() == "i32"
2533 ));
2534 }
2535
2536 #[test]
2537 fn reject_user_redefinition_of_prelude_adt_name() {
2538 let program = parse_program("type Result e a = Nope e a");
2539 let mut ts = TypeSystem::new_with_prelude().unwrap();
2540 let rex_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2541 panic!("expected type decl");
2542 };
2543 let err = ts.register_type_decl(decl).unwrap_err();
2544 assert!(matches!(
2545 err,
2546 TypeError::ReservedTypeName(name) if name.as_ref() == "Result"
2547 ));
2548 }
2549
2550 #[test]
2551 fn reject_user_redefinition_of_promise_type_name() {
2552 let program = parse_program("type Promise a = PromiseWrap a");
2553 let mut ts = TypeSystem::new_with_prelude().unwrap();
2554 let rex_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2555 panic!("expected type decl");
2556 };
2557 let err = ts.register_type_decl(decl).unwrap_err();
2558 assert!(matches!(
2559 err,
2560 TypeError::ReservedTypeName(name) if name.as_ref() == "Promise"
2561 ));
2562 }
2563
2564 #[test]
2565 fn infer_polymorphic_id_tuple() {
2566 let expr = parse_expr(
2567 r#"
2568 let
2569 id = \x -> x
2570 in
2571 id (id 420, id 6.9, id "str")
2572 "#,
2573 );
2574 let mut ts = TypeSystem::new_with_prelude().unwrap();
2575 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2576 let expected = Type::tuple(vec![
2577 Type::builtin(BuiltinTypeId::I32),
2578 Type::builtin(BuiltinTypeId::F32),
2579 Type::builtin(BuiltinTypeId::String),
2580 ]);
2581 assert_eq!(ty, expected);
2582 }
2583
2584 #[test]
2585 fn infer_type_annotation_ok() {
2586 let expr = parse_expr("let x: i32 = 42 in x");
2587 let mut ts = TypeSystem::new_with_prelude().unwrap();
2588 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2589 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2590 }
2591
2592 #[test]
2593 fn infer_type_annotation_lambda_param() {
2594 let expr = parse_expr("\\ (a : f32) -> a");
2595 let mut ts = TypeSystem::new_with_prelude().unwrap();
2596 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2597 assert_eq!(
2598 ty,
2599 Type::fun(
2600 Type::builtin(BuiltinTypeId::F32),
2601 Type::builtin(BuiltinTypeId::F32)
2602 )
2603 );
2604 }
2605
2606 #[test]
2607 fn infer_type_annotation_is_alias() {
2608 let expr = parse_expr("\"hi\" is str");
2609 let mut ts = TypeSystem::new_with_prelude().unwrap();
2610 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2611 assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
2612 }
2613
2614 #[test]
2615 fn infer_type_annotation_with_promise_constructor() {
2616 let expr = parse_expr("\\(x: Promise i32) -> x");
2617 let mut ts = TypeSystem::new_with_prelude().unwrap();
2618 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2619 let promise_i32 = Type::promise(Type::builtin(BuiltinTypeId::I32));
2620 assert_eq!(ty, Type::fun(promise_i32.clone(), promise_i32));
2621 }
2622
2623 #[test]
2624 fn infer_type_annotation_mismatch_error() {
2625 let expr = parse_expr("let x: i32 = 3.14 in x");
2626 let mut ts = TypeSystem::new_with_prelude().unwrap();
2627 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
2628 assert!(matches!(err, TypeError::Unification(_, _)));
2629 }
2630
2631 #[test]
2632 fn infer_project_single_variant_let() {
2633 let program = parse_program(
2634 r#"
2635 type MyADT = MyVariant1 { field1: i32, field2: f32 }
2636 let
2637 x = MyVariant1 { field1 = 1, field2 = 2.0 }
2638 in
2639 (x.field1, x.field2)
2640 "#,
2641 );
2642 let mut ts = TypeSystem::new_with_prelude().unwrap();
2643 for decl in &program.decls {
2644 if let rex_ast::expr::Decl::Type(decl) = decl {
2645 ts.register_type_decl(decl).unwrap();
2646 }
2647 }
2648 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2649 let expected = Type::tuple(vec![
2650 Type::builtin(BuiltinTypeId::I32),
2651 Type::builtin(BuiltinTypeId::F32),
2652 ]);
2653 assert_eq!(ty, expected);
2654 }
2655
2656 #[test]
2657 fn infer_project_known_variant_let() {
2658 let program = parse_program(
2659 r#"
2660 type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
2661 let
2662 x = MyVariant1 { field1 = 1, field2 = 2.0 }
2663 in
2664 x.field1
2665 "#,
2666 );
2667 let mut ts = TypeSystem::new_with_prelude().unwrap();
2668 for decl in &program.decls {
2669 if let rex_ast::expr::Decl::Type(decl) = decl {
2670 ts.register_type_decl(decl).unwrap();
2671 }
2672 }
2673 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2674 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2675 }
2676
2677 #[test]
2678 fn infer_project_unknown_variant_error() {
2679 let program = parse_program(
2680 r#"
2681 type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
2682 let
2683 x = MyVariant2 1 2.0
2684 in
2685 x.field1
2686 "#,
2687 );
2688 let mut ts = TypeSystem::new_with_prelude().unwrap();
2689 for decl in &program.decls {
2690 if let rex_ast::expr::Decl::Type(decl) = decl {
2691 ts.register_type_decl(decl).unwrap();
2692 }
2693 }
2694 let err = strip_span(infer(&mut ts, program.expr.as_ref()).unwrap_err());
2695 assert!(matches!(err, TypeError::FieldNotKnown { .. }));
2696 }
2697
2698 #[test]
2699 fn infer_project_lambda_param_single_variant() {
2700 let program = parse_program(
2701 r#"
2702 type Boxed = Boxed { value: i32 }
2703 let
2704 f = \x -> x.value
2705 in
2706 f (Boxed { value = 1 })
2707 "#,
2708 );
2709 let mut ts = TypeSystem::new_with_prelude().unwrap();
2710 for decl in &program.decls {
2711 if let rex_ast::expr::Decl::Type(decl) = decl {
2712 ts.register_type_decl(decl).unwrap();
2713 }
2714 }
2715 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2716 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2717 }
2718
2719 #[test]
2720 fn infer_project_in_match_arm() {
2721 let program = parse_program(
2722 r#"
2723 type MyADT = MyVariant1 { field1: i32 } | MyVariant2 i32
2724 let
2725 x = MyVariant1 { field1 = 1 }
2726 in
2727 match x
2728 when MyVariant1 { field1 } -> x.field1
2729 when MyVariant2 _ -> 0
2730 "#,
2731 );
2732 let mut ts = TypeSystem::new_with_prelude().unwrap();
2733 for decl in &program.decls {
2734 if let rex_ast::expr::Decl::Type(decl) = decl {
2735 ts.register_type_decl(decl).unwrap();
2736 }
2737 }
2738 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2739 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2740 }
2741
2742 #[test]
2743 fn infer_nested_let_lambda_match_option() {
2744 let expr = parse_expr(
2745 r#"
2746 let
2747 choose = \flag a b -> if flag then a else b,
2748 build = \flag ->
2749 let
2750 pick = choose flag,
2751 val = pick 1 2
2752 in
2753 Some val
2754 in
2755 match (build true)
2756 when Some x -> x
2757 when None -> 0
2758 "#,
2759 );
2760 let mut ts = TypeSystem::new_with_prelude().unwrap();
2761 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2762 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2763 }
2764
2765 #[test]
2766 fn infer_polymorphic_apply_in_tuple() {
2767 let expr = parse_expr(
2768 r#"
2769 let
2770 apply = \f x -> f x,
2771 id = \x -> x,
2772 wrap = \x -> (x, x)
2773 in
2774 (apply id 1, apply id "hi", apply wrap true)
2775 "#,
2776 );
2777 let mut ts = TypeSystem::new_with_prelude().unwrap();
2778 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2779 let expected = Type::tuple(vec![
2780 Type::builtin(BuiltinTypeId::I32),
2781 Type::builtin(BuiltinTypeId::String),
2782 Type::tuple(vec![
2783 Type::builtin(BuiltinTypeId::Bool),
2784 Type::builtin(BuiltinTypeId::Bool),
2785 ]),
2786 ]);
2787 assert_eq!(ty, expected);
2788 }
2789
2790 #[test]
2791 fn infer_nested_result_option_match() {
2792 let expr = parse_expr(
2793 r#"
2794 let
2795 unwrap = \x ->
2796 match x
2797 when Ok (Some v) -> v
2798 when Ok None -> 0
2799 when Err _ -> 0
2800 in
2801 unwrap (Ok (Some 5))
2802 "#,
2803 );
2804 let mut ts = TypeSystem::new_with_prelude().unwrap();
2805 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2806 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2807 }
2808
2809 #[test]
2810 fn infer_head_or_list_match() {
2811 let expr = parse_expr(
2812 r#"
2813 let
2814 head_or = \fallback xs ->
2815 match xs
2816 when [] -> fallback
2817 when x::xs -> x
2818 in
2819 (head_or 0 [1, 2, 3], head_or 0 [])
2820 "#,
2821 );
2822 let mut ts = TypeSystem::new_with_prelude().unwrap();
2823 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2824 let expected = Type::tuple(vec![
2825 Type::builtin(BuiltinTypeId::I32),
2826 Type::builtin(BuiltinTypeId::I32),
2827 ]);
2828 assert_eq!(ty, expected);
2829 }
2830
2831 #[test]
2832 fn infer_head_or_list_match_cons_constructor_form() {
2833 let expr = parse_expr(
2834 r#"
2835 let
2836 head_or = \fallback xs ->
2837 match xs
2838 when [] -> fallback
2839 when Cons x xs1 -> x
2840 in
2841 (head_or 0 (Cons 1 (Cons 2 Empty)), head_or 0 Empty)
2842 "#,
2843 );
2844 let mut ts = TypeSystem::new_with_prelude().unwrap();
2845 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2846 let expected = Type::tuple(vec![
2847 Type::builtin(BuiltinTypeId::I32),
2848 Type::builtin(BuiltinTypeId::I32),
2849 ]);
2850 assert_eq!(ty, expected);
2851 }
2852
2853 #[test]
2854 fn infer_record_pattern_in_lambda() {
2855 let program = parse_program(
2856 r#"
2857 type Pair = Pair { left: i32, right: i32 }
2858 let
2859 sum = \p ->
2860 match p
2861 when Pair { left, right } -> left + right
2862 in
2863 sum (Pair { left = 1, right = 2 })
2864 "#,
2865 );
2866 let mut ts = TypeSystem::new_with_prelude().unwrap();
2867 for decl in &program.decls {
2868 if let rex_ast::expr::Decl::Type(decl) = decl {
2869 ts.register_type_decl(decl).unwrap();
2870 }
2871 }
2872 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2873 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2874 }
2875
2876 #[test]
2877 fn infer_fn_decl_simple() {
2878 let program = parse_program(
2879 r#"
2880 fn add (x: i32, y: i32) -> i32 = x + y
2881 add 1 2
2882 "#,
2883 );
2884 let mut ts = TypeSystem::new_with_prelude().unwrap();
2885 let expr = program.expr_with_fns();
2886 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2887 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2888 }
2889
2890 #[test]
2891 fn infer_fn_decl_signature_form() {
2892 let program = parse_program(
2893 r#"
2894 fn add : i32 -> i32 -> i32 = \x y -> x + y
2895 add 1 2
2896 "#,
2897 );
2898 let mut ts = TypeSystem::new_with_prelude().unwrap();
2899 let expr = program.expr_with_fns();
2900 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2901 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2902 }
2903
2904 #[test]
2905 fn infer_fn_decl_polymorphic_where_constraints() {
2906 let program = parse_program(
2907 r#"
2908 fn my_add (x: a, y: a) -> a where AdditiveMonoid a = x + y
2909 (my_add 1 2, my_add 1.0 2.0)
2910 "#,
2911 );
2912 let mut ts = TypeSystem::new_with_prelude().unwrap();
2913 let expr = program.expr_with_fns();
2914 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2915 assert_eq!(
2916 ty,
2917 Type::tuple(vec![
2918 Type::builtin(BuiltinTypeId::I32),
2919 Type::builtin(BuiltinTypeId::F32)
2920 ])
2921 );
2922 }
2923
2924 #[test]
2925 fn infer_additive_monoid_constraint() {
2926 let expr = parse_expr("\\x y -> x + y");
2927 let mut ts = TypeSystem::new_with_prelude().unwrap();
2928 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2929 assert_eq!(preds.len(), 1);
2930 assert_eq!(preds[0].class.as_ref(), "AdditiveMonoid");
2931
2932 if let TypeKind::Fun(a, rest) = ty.as_ref()
2933 && let TypeKind::Fun(b, c) = rest.as_ref()
2934 {
2935 assert_eq!(a.as_ref(), b.as_ref());
2936 assert_eq!(b.as_ref(), c.as_ref());
2937 assert_eq!(preds[0].typ, a.clone());
2938 return;
2939 }
2940 panic!("expected a -> a -> a");
2941 }
2942
2943 #[test]
2944 fn infer_multiplicative_monoid_constraint() {
2945 let expr = parse_expr("\\x y -> x * y");
2946 let mut ts = TypeSystem::new_with_prelude().unwrap();
2947 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2948 assert_eq!(preds.len(), 1);
2949 assert_eq!(preds[0].class.as_ref(), "MultiplicativeMonoid");
2950
2951 if let TypeKind::Fun(a, rest) = ty.as_ref()
2952 && let TypeKind::Fun(b, c) = rest.as_ref()
2953 {
2954 assert_eq!(a.as_ref(), b.as_ref());
2955 assert_eq!(b.as_ref(), c.as_ref());
2956 assert_eq!(preds[0].typ, a.clone());
2957 return;
2958 }
2959 panic!("expected a -> a -> a");
2960 }
2961
2962 #[test]
2963 fn infer_additive_group_constraint() {
2964 let expr = parse_expr("\\x y -> x - y");
2965 let mut ts = TypeSystem::new_with_prelude().unwrap();
2966 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2967 assert_eq!(preds.len(), 1);
2968 assert_eq!(preds[0].class.as_ref(), "AdditiveGroup");
2969
2970 if let TypeKind::Fun(a, rest) = ty.as_ref()
2971 && let TypeKind::Fun(b, c) = rest.as_ref()
2972 {
2973 assert_eq!(a.as_ref(), b.as_ref());
2974 assert_eq!(b.as_ref(), c.as_ref());
2975 assert_eq!(preds[0].typ, a.clone());
2976 return;
2977 }
2978 panic!("expected a -> a -> a");
2979 }
2980
2981 #[test]
2982 fn infer_integral_constraint() {
2983 let expr = parse_expr("\\x y -> x % y");
2984 let mut ts = TypeSystem::new_with_prelude().unwrap();
2985 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2986 assert_eq!(preds.len(), 1);
2987 assert_eq!(preds[0].class.as_ref(), "Integral");
2988
2989 if let TypeKind::Fun(a, rest) = ty.as_ref()
2990 && let TypeKind::Fun(b, c) = rest.as_ref()
2991 {
2992 assert_eq!(a.as_ref(), b.as_ref());
2993 assert_eq!(b.as_ref(), c.as_ref());
2994 assert_eq!(preds[0].typ, a.clone());
2995 return;
2996 }
2997 panic!("expected a -> a -> a");
2998 }
2999
3000 #[test]
3001 fn infer_literal_addition_defaults() {
3002 let expr = parse_expr("1 + 2");
3003 let mut ts = TypeSystem::new_with_prelude().unwrap();
3004 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3005 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3006 assert_eq!(preds.len(), 2);
3007 assert!(preds.iter().any(|p| p.class.as_ref() == "AdditiveMonoid"));
3008 assert!(preds.iter().any(|p| p.class.as_ref() == "Integral"));
3009 assert!(
3010 preds
3011 .iter()
3012 .all(|p| p.typ == Type::builtin(BuiltinTypeId::I32))
3013 );
3014 }
3015
3016 #[test]
3017 fn infer_mod_defaults() {
3018 let expr = parse_expr("1 % 2");
3019 let mut ts = TypeSystem::new_with_prelude().unwrap();
3020 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3021 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3022 assert_eq!(preds.len(), 1);
3023 assert_eq!(preds[0].class.as_ref(), "Integral");
3024 assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::I32));
3025 }
3026
3027 #[test]
3028 fn infer_get_list_type() {
3029 let expr = parse_expr("get 1 [1, 2, 3]");
3030 let mut ts = TypeSystem::new_with_prelude().unwrap();
3031 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3032 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3033 assert!(preds.iter().any(|p| p.class.as_ref() == "Indexable"));
3034 assert!(preds.iter().all(|p| {
3035 p.class.as_ref() == "Indexable"
3036 || (p.class.as_ref() == "Integral" && p.typ == Type::builtin(BuiltinTypeId::I32))
3037 }));
3038 for pred in preds.iter().filter(|p| p.class.as_ref() == "Indexable") {
3039 assert!(entails(&ts.classes, &[], pred).unwrap());
3040 }
3041 }
3042
3043 #[test]
3044 fn infer_get_tuple_type() {
3045 let expr = parse_expr("(1, 'Hello', true).0");
3046 let mut ts = TypeSystem::new_with_prelude().unwrap();
3047 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3048 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3049 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3050
3051 let expr = parse_expr("(1, 'Hello', true).1");
3052 let mut ts = TypeSystem::new_with_prelude().unwrap();
3053 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3054 assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
3055 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3056
3057 let expr = parse_expr("(1, 'Hello', true).2");
3058 let mut ts = TypeSystem::new_with_prelude().unwrap();
3059 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3060 assert_eq!(ty, Type::builtin(BuiltinTypeId::Bool));
3061 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3062 }
3063
3064 #[test]
3065 fn infer_division_defaults() {
3066 let expr = parse_expr("1.0 / 2.0");
3067 let mut ts = TypeSystem::new_with_prelude().unwrap();
3068 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3069 assert_eq!(ty, Type::builtin(BuiltinTypeId::F32));
3070 assert_eq!(preds.len(), 1);
3071 assert_eq!(preds[0].class.as_ref(), "Field");
3072 assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::F32));
3073 assert!(entails(&ts.classes, &[], &preds[0]).unwrap());
3074 }
3075
3076 #[test]
3077 fn infer_unbound_variable_error() {
3078 let expr = parse_expr("missing");
3079 let mut ts = TypeSystem::new_with_prelude().unwrap();
3080 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3081 assert!(matches!(
3082 err,
3083 TypeError::UnknownVar(name) if name.as_ref() == "missing"
3084 ));
3085 }
3086
3087 #[test]
3088 fn infer_if_branch_type_mismatch_error() {
3089 let expr = parse_expr(r#"if true then 1 else "no""#);
3090 let mut ts = TypeSystem::new_with_prelude().unwrap();
3091 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3092 match err {
3093 TypeError::Unification(a, b) => {
3094 let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
3095 assert!(ok, "expected i32 vs string, got {a} vs {b}");
3096 }
3097 other => panic!("expected unification error, got {other:?}"),
3098 }
3099 }
3100
3101 #[test]
3102 fn infer_unknown_pattern_constructor_error() {
3103 let expr = parse_expr("match 1 when Nope -> 1");
3104 let mut ts = TypeSystem::new_with_prelude().unwrap();
3105 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3106 assert!(matches!(
3107 err,
3108 TypeError::UnknownVar(name) if name.as_ref() == "Nope"
3109 ));
3110 }
3111
3112 #[test]
3113 fn infer_ambiguous_overload_error() {
3114 let mut ts = TypeSystem::new();
3115 let a = TypeVar::new(0, Some(sym("a")));
3116 let b = TypeVar::new(1, Some(sym("b")));
3117 let scheme_a = Scheme::new(vec![a.clone()], vec![], Type::var(a));
3118 let scheme_b = Scheme::new(vec![b.clone()], vec![], Type::var(b));
3119 ts.add_overload(sym("dup"), scheme_a);
3120 ts.add_overload(sym("dup"), scheme_b);
3121 let expr = parse_expr("dup");
3122 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3123 assert!(matches!(
3124 err,
3125 TypeError::AmbiguousOverload(name) if name.as_ref() == "dup"
3126 ));
3127 }
3128
3129 #[test]
3130 fn infer_if_cond_not_bool_error() {
3131 let expr = parse_expr("if 1 then 2 else 3");
3132 let mut ts = TypeSystem::new_with_prelude().unwrap();
3133 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3134 match err {
3135 TypeError::Unification(a, b) => {
3136 let ok = (a == "bool" && b == "i32") || (a == "i32" && b == "bool");
3137 assert!(ok, "expected bool vs i32, got {a} vs {b}");
3138 }
3139 other => panic!("expected unification error, got {other:?}"),
3140 }
3141 }
3142
3143 #[test]
3144 fn infer_apply_non_function_error() {
3145 let expr = parse_expr("1 2");
3146 let mut ts = TypeSystem::new_with_prelude().unwrap();
3147 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3148 assert!(matches!(err, TypeError::Unification(_, _)));
3149 }
3150
#[test]
fn infer_list_element_mismatch_error() {
    // All list elements must share one type, so mixing an integer and a
    // boolean fails. The two sides may be reported in either order.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("[1, true]");
    let error = strip_span(infer(&mut ts, parsed.as_ref()).unwrap_err());
    match error {
        TypeError::Unification(a, b) => {
            let ok = [("i32", "bool"), ("bool", "i32")]
                .iter()
                .any(|&(lhs, rhs)| a == lhs && b == rhs);
            assert!(ok, "expected i32 vs bool, got {a} vs {b}");
        }
        other => panic!("expected unification error, got {other:?}"),
    }
}
3164
#[test]
fn infer_dict_value_mismatch_error() {
    // This test expects `{a = 1, b = true}` to be rejected with an i32/bool
    // unification error. The two sides may be reported in either order.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("{a = 1, b = true}");
    let error = strip_span(infer(&mut ts, parsed.as_ref()).unwrap_err());
    match error {
        TypeError::Unification(a, b) => {
            let ok = [("i32", "bool"), ("bool", "i32")]
                .iter()
                .any(|&(lhs, rhs)| a == lhs && b == rhs);
            assert!(ok, "expected i32 vs bool, got {a} vs {b}");
        }
        other => panic!("expected unification error, got {other:?}"),
    }
}
3178
#[test]
fn infer_match_list_on_non_list_error() {
    // A list pattern applied to an integer scrutinee must be rejected.
    // Only failure is checked here, not the specific error variant.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("match 1 when [x] -> x");
    let result = infer(&mut ts, parsed.as_ref());
    assert!(result.is_err());
}
3185
#[test]
fn infer_pattern_constructor_arity_error() {
    // `Ok x y` supplies more sub-patterns than the `Ok` constructor takes,
    // which is reported as an unsupported pattern constructor.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("match (Ok 1) when Ok x y -> x");
    let result = infer(&mut ts, parsed.as_ref());
    let error = strip_span(result.unwrap_err());
    assert!(matches!(
        error,
        TypeError::UnsupportedExpr("pattern constructor")
    ));
}
3196
#[test]
fn infer_match_arm_type_mismatch_error() {
    // Every arm of a `match` must produce the same type. An integer arm next
    // to a string arm is a unification error, reported in either order.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr(r#"match 1 when _ -> 1 when _ -> "no""#);
    let error = strip_span(infer(&mut ts, parsed.as_ref()).unwrap_err());
    match error {
        TypeError::Unification(a, b) => {
            let ok = [("i32", "string"), ("string", "i32")]
                .iter()
                .any(|&(lhs, rhs)| a == lhs && b == rhs);
            assert!(ok, "expected i32 vs string, got {a} vs {b}");
        }
        other => panic!("expected unification error, got {other:?}"),
    }
}
3210
#[test]
fn infer_match_option_on_non_option_error() {
    // A `Some` pattern against an integer scrutinee must be rejected.
    // Only failure is checked, not the specific error variant.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("match 1 when Some x -> x");
    let result = infer(&mut ts, parsed.as_ref());
    assert!(result.is_err());
}
3217
#[test]
fn infer_dict_pattern_on_non_dict_error() {
    // A record/dict pattern against an integer scrutinee is expected to be
    // a unification failure.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("match 1 when {a} -> a");
    let result = infer(&mut ts, parsed.as_ref());
    let error = strip_span(result.unwrap_err());
    assert!(matches!(error, TypeError::Unification(..)));
}
3225
#[test]
fn infer_cons_pattern_on_non_list_error() {
    // A cons (`x::xs`) pattern against an integer scrutinee must be
    // rejected. Only failure is checked, not the specific error variant.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("match 1 when x::xs -> x");
    let result = infer(&mut ts, parsed.as_ref());
    assert!(result.is_err());
}
3232
#[test]
fn infer_apply_wrong_arg_type_error() {
    // The lambda adds `1` to its argument, so passing `true` is expected to
    // be a unification failure.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr(r"(\x -> x + 1) true");
    let result = infer(&mut ts, parsed.as_ref());
    let error = strip_span(result.unwrap_err());
    assert!(matches!(error, TypeError::Unification(..)));
}
3240
#[test]
fn infer_self_application_occurs_error() {
    // `x x` would require a type variable to contain itself, which the
    // occurs check must reject.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr(r"\x -> x x");
    let result = infer(&mut ts, parsed.as_ref());
    let error = strip_span(result.unwrap_err());
    assert!(matches!(error, TypeError::Occurs(..)));
}
3248
#[test]
fn infer_apply_constructor_too_many_args_error() {
    // Applying `Some 1` to a further argument is expected to be a
    // unification failure.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("Some 1 2");
    let result = infer(&mut ts, parsed.as_ref());
    let error = strip_span(result.unwrap_err());
    assert!(matches!(error, TypeError::Unification(..)));
}
3256
#[test]
fn infer_operator_type_mismatch_error() {
    // `+` with an integer and a boolean operand is expected to be a
    // unification failure.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("1 + true");
    let result = infer(&mut ts, parsed.as_ref());
    let error = strip_span(result.unwrap_err());
    assert!(matches!(error, TypeError::Unification(..)));
}
3264
#[test]
fn infer_non_exhaustive_match_is_error() {
    // Only the `Ok` case is handled, so the match must be reported as
    // non-exhaustive.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("match (Ok 1) when Ok x -> x");
    let result = infer(&mut ts, parsed.as_ref());
    let error = strip_span(result.unwrap_err());
    assert!(matches!(error, TypeError::NonExhaustiveMatch { .. }));
}
3272
#[test]
fn infer_non_exhaustive_match_on_bound_var_error() {
    // Exhaustiveness must also be checked when the scrutinee is a
    // `let`-bound variable rather than a literal constructor application.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("let x = Ok 1 in match x when Ok y -> y");
    let result = infer(&mut ts, parsed.as_ref());
    let error = strip_span(result.unwrap_err());
    assert!(matches!(error, TypeError::NonExhaustiveMatch { .. }));
}
3280
#[test]
fn infer_non_exhaustive_match_in_lambda_error() {
    // The lambda parameter's type is only learned from the `Ok` pattern, and
    // the missing case must still be detected.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("\\x -> match x when Ok y -> y");
    let result = infer(&mut ts, parsed.as_ref());
    let error = strip_span(result.unwrap_err());
    assert!(matches!(error, TypeError::NonExhaustiveMatch { .. }));
}
3288
#[test]
fn infer_non_exhaustive_option_match_error() {
    // Handling only `Some` must report exactly `None` as the missing case.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("match (Some 1) when Some x -> x");
    let error = strip_span(infer(&mut ts, parsed.as_ref()).unwrap_err());
    match error {
        TypeError::NonExhaustiveMatch { missing, .. } => {
            assert_eq!(missing, vec![sym("None")])
        }
        other => panic!("expected non-exhaustive match, got {other:?}"),
    }
}
3301
#[test]
fn infer_non_exhaustive_result_match_error() {
    // Handling only `Ok` must report exactly `Err` as the missing case.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("match (Err 1) when Ok x -> x");
    let error = strip_span(infer(&mut ts, parsed.as_ref()).unwrap_err());
    match error {
        TypeError::NonExhaustiveMatch { missing, .. } => {
            assert_eq!(missing, vec![sym("Err")])
        }
        other => panic!("expected non-exhaustive match, got {other:?}"),
    }
}
3314
#[test]
fn infer_non_exhaustive_list_missing_empty_error() {
    // A cons-only match over a list must report the empty-list case
    // (`Empty`) as missing.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("match [1, 2] when x::xs -> x");
    let error = strip_span(infer(&mut ts, parsed.as_ref()).unwrap_err());
    match error {
        TypeError::NonExhaustiveMatch { missing, .. } => {
            assert_eq!(missing, vec![sym("Empty")])
        }
        other => panic!("expected non-exhaustive match, got {other:?}"),
    }
}
3327
#[test]
fn infer_non_exhaustive_list_match_on_bound_var_error() {
    // Same as the literal-list case, but the scrutinee is a `let`-bound
    // variable (which the cons pattern then shadows).
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("let xs = [1, 2] in match xs when x::xs -> x");
    let result = infer(&mut ts, parsed.as_ref());
    let error = strip_span(result.unwrap_err());
    assert!(matches!(error, TypeError::NonExhaustiveMatch { .. }));
}
3335
#[test]
fn infer_non_exhaustive_list_missing_cons_error() {
    // Handling only the empty list must report `Cons` as the missing case.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("match [1] when [] -> 0");
    let error = strip_span(infer(&mut ts, parsed.as_ref()).unwrap_err());
    match error {
        TypeError::NonExhaustiveMatch { missing, .. } => {
            assert_eq!(missing, vec![sym("Cons")])
        }
        other => panic!("expected non-exhaustive match, got {other:?}"),
    }
}
3348
#[test]
fn infer_match_list_patterns_on_result_error() {
    // List patterns against a `Result` scrutinee are expected to be a
    // unification failure, not an exhaustiveness error.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("match (Ok 1) when [] -> 0 when x::xs -> 1");
    let result = infer(&mut ts, parsed.as_ref());
    let error = strip_span(result.unwrap_err());
    assert!(matches!(error, TypeError::Unification(..)));
}
3356
#[test]
fn infer_missing_instances_produce_unsatisfied_predicates() {
    // Each case type-checks without error but leaves behind a class predicate
    // that the prelude's instances cannot discharge. The test checks that the
    // predicate is present in `infer`'s output and that `entails` (with no
    // extra givens) reports it as unsatisfied.
    //
    // Columns: case name, source, class of the expected predicate, the type
    // that predicate applies to, and an optional expected result type.
    // Keeping everything in one table means a new case is a single entry,
    // instead of a name that must also be re-dispatched in a second `match`.
    let cases = [
        (
            "division",
            "1 / 2",
            "Field",
            Type::builtin(BuiltinTypeId::I32),
            Some(Type::builtin(BuiltinTypeId::I32)),
        ),
        (
            "eq_dict",
            "{a = 1} == {a = 2}",
            "Eq",
            dict_of(Type::builtin(BuiltinTypeId::I32)),
            None,
        ),
        (
            "min_bool",
            "min [true]",
            "Ord",
            Type::builtin(BuiltinTypeId::Bool),
            None,
        ),
        (
            "map_dict",
            r#"map (\x -> x) {a = 1}"#,
            "Functor",
            Type::builtin(BuiltinTypeId::Dict),
            None,
        ),
    ];

    for (name, code, class, pred_type, expected_ty) in cases {
        let expr = parse_expr(code);
        let mut ts = TypeSystem::new_with_prelude().unwrap();
        let (preds, ty) = infer(&mut ts, expr.as_ref())
            .unwrap_or_else(|e| panic!("{name}: inference failed: {e:?}"));
        if let Some(expected) = expected_ty {
            assert_eq!(ty, expected, "{name}");
        }

        let pred = preds
            .iter()
            .find(|p| p.class.as_ref() == class && p.typ == pred_type)
            .unwrap_or_else(|| {
                // Name the failing case and list what was actually produced,
                // so a regression points straight at the culprit.
                let found: Vec<String> = preds
                    .iter()
                    .map(|p| {
                        let c: &str = p.class.as_ref();
                        format!("{c} {:?}", p.typ)
                    })
                    .collect();
                panic!("{name}: no `{class}` predicate on {pred_type:?}; found {found:?}")
            });
        let satisfied = entails(&ts.classes, &[], pred)
            .unwrap_or_else(|e| panic!("{name}: entails failed: {e:?}"));
        assert!(!satisfied, "{name}");
    }
}
3391
#[test]
fn record_update_single_variant_adt_infers() {
    // `Foo` has exactly one variant, so a record update on a `Foo` value is
    // accepted without any refinement, and the result keeps the type `Foo`.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let program = parse_program(
        r#"
type Foo = Bar { x: i32, y: i32 }
let
foo: Foo = Bar { x = 1, y = 2 },
bar = { foo with { x = 3 } }
in
bar
"#,
    );
    ts.register_decls(&program.decls).unwrap();
    let (_, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
    assert_eq!(typ.to_string(), "Foo");
}
3409
#[test]
fn record_update_unknown_field_errors() {
    // `Bar` declares only `x`, so updating `y` must be an unknown-field error.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let program = parse_program(
        r#"
type Foo = Bar { x: i32 }
let
foo: Foo = Bar { x = 1 }
in
{ foo with { y = 2 } }
"#,
    );
    ts.register_decls(&program.decls).unwrap();
    let error = strip_span(infer(&mut ts, program.expr.as_ref()).unwrap_err());
    assert!(matches!(error, TypeError::UnknownField { .. }));
}
3427
#[test]
fn record_update_requires_refined_variant_for_sum_types() {
    // `Foo` has two variants. Inside `f`, `foo` is only known to be some
    // `Foo`, so the field update must be rejected with `FieldNotKnown`.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let program = parse_program(
        r#"
type Foo = Bar { x: i32 } | Baz { x: i32 }
let
f = \ (foo : Foo) -> { foo with { x = 2 } }
in
f (Bar { x = 1 })
"#,
    );
    ts.register_decls(&program.decls).unwrap();
    let error = strip_span(infer(&mut ts, program.expr.as_ref()).unwrap_err());
    assert!(matches!(error, TypeError::FieldNotKnown { .. }));
}
3445
#[test]
fn record_update_allowed_after_match_refines_variant() {
    // Inside each `match` arm the variant of `foo` is known, so the same
    // update that fails without a `match` is accepted here.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let program = parse_program(
        r#"
type Foo = Bar { x: i32 } | Baz { x: i32 }
let
f = \ (foo : Foo) ->
match foo
when Bar {x} -> { foo with { x = x + 1 } }
when Baz {x} -> { foo with { x = x + 2 } }
in
f (Bar { x = 1 })
"#,
    );
    ts.register_decls(&program.decls).unwrap();
    let (_, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
    assert_eq!(typ.to_string(), "Foo");
}
3465
#[test]
fn record_update_plain_record_type() {
    // A structural record annotation (no ADT involved) also supports update
    // syntax, and the result keeps the annotated record type.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let program = parse_program(
        r#"
let
f = \ (r : { x: i32, y: i32 }) -> { r with { y = 9 } }
in
f { x = 1, y = 2 }
"#,
    );
    ts.register_decls(&program.decls).unwrap();
    let (_, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
    assert_eq!(typ.to_string(), "{x: i32, y: i32}");
}
3481
#[test]
fn infer_typed_hole_expr_is_hole_kind() {
    // A bare `?` must appear in the typed tree as a `Hole` node.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("?");
    let (typed, _, _) = infer_typed(&mut ts, parsed.as_ref()).unwrap();
    assert!(matches!(typed.kind, TypedExprKind::Hole), "typed={typed:#?}");
}
3492
#[test]
fn infer_hole_with_annotation_unifies_to_annotation() {
    // The `: i32` annotation on the binding fixes the hole's type, so the
    // body `x` has type `i32`.
    let expected = Type::builtin(BuiltinTypeId::I32);
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("let x : i32 = ? in x");
    let (_, ty) = infer(&mut ts, parsed.as_ref()).unwrap();
    assert_eq!(ty, expected);
}
3500
#[test]
fn infer_hole_in_if_condition_is_bool_constrained() {
    // A hole in condition position must not block inference. Only the
    // overall result type (`i32`, from the branches) is checked; the hole's
    // own type is not inspected here.
    let expected = Type::builtin(BuiltinTypeId::I32);
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("if ? then 1 else 2");
    let (_, ty) = infer(&mut ts, parsed.as_ref()).unwrap();
    assert_eq!(ty, expected);
}
3508
#[test]
fn infer_hole_in_arithmetic_is_numeric_constrained() {
    // Adding a hole to an integer literal must still infer, with the whole
    // expression coming out as `i32`.
    let expected = Type::builtin(BuiltinTypeId::I32);
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("? + 1");
    let (_, ty) = infer(&mut ts, parsed.as_ref()).unwrap();
    assert_eq!(ty, expected);
}
3516
#[test]
fn infer_hole_arithmetic_conflicting_annotation_failure() {
    // The arithmetic `? + 1` is expected to be numeric, which conflicts with
    // the `string` annotation and must surface as a unification failure.
    let mut ts = TypeSystem::new_with_prelude().unwrap();
    let parsed = parse_expr("let x : string = (? + 1) in x");
    let result = infer(&mut ts, parsed.as_ref());
    let err = strip_span(result.unwrap_err());
    assert!(matches!(err, TypeError::Unification(..)), "err={err:#?}");
}
3524}