1use super::*;
2
3fn dedup_preds(preds: Vec<Predicate>) -> Vec<Predicate> {
4 let mut seen = HashSet::new();
5 let mut out = Vec::with_capacity(preds.len());
6 for pred in preds {
7 if seen.insert(pred.clone()) {
8 out.push(pred);
9 }
10 }
11 out
12}
13
14fn is_integral_primitive(typ: &Type) -> bool {
15 matches!(
16 typ.as_ref(),
17 TypeKind::Con(TypeConst {
18 builtin_id: Some(
19 BuiltinTypeId::U8
20 | BuiltinTypeId::U16
21 | BuiltinTypeId::U32
22 | BuiltinTypeId::U64
23 | BuiltinTypeId::I8
24 | BuiltinTypeId::I16
25 | BuiltinTypeId::I32
26 | BuiltinTypeId::I64
27 ),
28 ..
29 })
30 )
31}
32
33fn finalize_infer_for_public_api(
34 mut preds: Vec<Predicate>,
35 mut typ: Type,
36) -> Result<(Vec<Predicate>, Type), TypeError> {
37 let mut subst = Subst::new_sync();
38 for pred in &preds {
39 if pred.class.as_ref() == "Integral"
40 && let TypeKind::Var(tv) = pred.typ.as_ref()
41 {
42 subst = subst.insert(tv.id, Type::builtin(BuiltinTypeId::I32));
43 }
44 }
45
46 if !subst_is_empty(&subst) {
47 preds = dedup_preds(preds.apply(&subst));
48 typ = typ.apply(&subst);
49 }
50
51 for pred in &preds {
52 if pred.class.as_ref() != "Integral" {
53 continue;
54 }
55 if matches!(pred.typ.as_ref(), TypeKind::Var(_)) || is_integral_primitive(&pred.typ) {
56 continue;
57 }
58 return Err(TypeError::Unification("i32".into(), pred.typ.to_string()));
59 }
60
61 Ok((preds, typ))
62}
63
64#[derive(Clone, Debug)]
65struct KnownVariant {
66 adt: Symbol,
67 variant: Symbol,
68}
69
70type KnownVariants = HashMap<Symbol, KnownVariant>;
71
72fn apply_scheme_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> Scheme {
73 let preds = scheme
74 .preds
75 .iter()
76 .map(|pred| Predicate::new(pred.class.clone(), unifier.apply_type(&pred.typ)))
77 .collect();
78 let typ = unifier.apply_type(&scheme.typ);
79 Scheme::new(scheme.vars.clone(), preds, typ)
80}
81
82fn scheme_ftv_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> HashSet<TypeVarId> {
83 let mut ftv = unifier.apply_type(&scheme.typ).ftv();
84 for pred in &scheme.preds {
85 ftv.extend(unifier.apply_type(&pred.typ).ftv());
86 }
87 for var in &scheme.vars {
88 ftv.remove(&var.id);
89 }
90 ftv
91}
92
93fn env_ftv_with_unifier(env: &TypeEnv, unifier: &mut Unifier<'_>) -> HashSet<TypeVarId> {
94 let mut out = HashSet::new();
95 for (_name, schemes) in env.values.iter() {
96 for scheme in schemes {
97 out.extend(scheme_ftv_with_unifier(scheme, unifier));
98 }
99 }
100 out
101}
102
103fn generalize_with_unifier(
104 env: &TypeEnv,
105 preds: Vec<Predicate>,
106 typ: Type,
107 unifier: &mut Unifier<'_>,
108) -> Scheme {
109 let preds: Vec<Predicate> = preds
110 .into_iter()
111 .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
112 .collect();
113 let typ = unifier.apply_type(&typ);
114 let mut vars: Vec<TypeVar> = typ
115 .ftv()
116 .union(&preds.ftv())
117 .copied()
118 .collect::<HashSet<_>>()
119 .difference(&env_ftv_with_unifier(env, unifier))
120 .cloned()
121 .map(|id| TypeVar::new(id, None))
122 .collect();
123 vars.sort_by_key(|v| v.id);
124 Scheme::new(vars, preds, typ)
125}
126
127fn monomorphic_scheme_with_unifier(
128 preds: Vec<Predicate>,
129 typ: Type,
130 unifier: &mut Unifier<'_>,
131) -> Scheme {
132 let preds = dedup_preds(
133 preds
134 .into_iter()
135 .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
136 .collect(),
137 );
138 let typ = unifier.apply_type(&typ);
139 Scheme::new(vec![], preds, typ)
140}
141
142pub fn infer_typed(
143 type_system: &mut TypeSystem,
144 expr: &Expr,
145) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
146 infer_typed_inner(type_system, expr)
147}
148
149pub fn infer_typed_with_gas(
150 type_system: &mut TypeSystem,
151 expr: &Expr,
152 gas: &mut GasMeter,
153) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
154 let known = KnownVariants::new();
155 let mut unifier = Unifier::with_gas(gas, type_system.limits.max_infer_depth);
156 let (preds, t, typed) = infer_expr(
157 &mut unifier,
158 &mut type_system.supply,
159 &type_system.env,
160 &type_system.adts,
161 &known,
162 expr,
163 )
164 .map_err(|err| with_span(expr.span(), err))?;
165 let subst = unifier.into_subst();
166 let mut typed = typed.apply(&subst);
167 let mut preds = dedup_preds(preds.apply(&subst));
168 let mut t = t.apply(&subst);
169 let improve = improve_indexable(&preds)?;
170 if !subst_is_empty(&improve) {
171 typed = typed.apply(&improve);
172 preds = dedup_preds(preds.apply(&improve));
173 t = t.apply(&improve);
174 }
175 type_system.check_predicate_kinds(&preds)?;
176 Ok((typed, preds, t))
177}
178
179fn infer_typed_inner(
180 type_system: &mut TypeSystem,
181 expr: &Expr,
182) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
183 let known = KnownVariants::new();
184 let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
185 let (preds, t, typed) = infer_expr(
186 &mut unifier,
187 &mut type_system.supply,
188 &type_system.env,
189 &type_system.adts,
190 &known,
191 expr,
192 )
193 .map_err(|err| with_span(expr.span(), err))?;
194 let subst = unifier.into_subst();
195 let mut typed = typed.apply(&subst);
196 let mut preds = dedup_preds(preds.apply(&subst));
197 let mut t = t.apply(&subst);
198 let improve = improve_indexable(&preds)?;
199 if !subst_is_empty(&improve) {
200 typed = typed.apply(&improve);
201 preds = dedup_preds(preds.apply(&improve));
202 t = t.apply(&improve);
203 }
204 type_system.check_predicate_kinds(&preds)?;
205 Ok((typed, preds, t))
206}
207
208pub fn infer(
209 type_system: &mut TypeSystem,
210 expr: &Expr,
211) -> Result<(Vec<Predicate>, Type), TypeError> {
212 infer_inner(type_system, expr)
213}
214
215pub fn infer_with_gas(
216 type_system: &mut TypeSystem,
217 expr: &Expr,
218 gas: &mut GasMeter,
219) -> Result<(Vec<Predicate>, Type), TypeError> {
220 let known = KnownVariants::new();
221 let mut unifier = Unifier::with_gas(gas, type_system.limits.max_infer_depth);
222 let (preds, t) = infer_expr_type(
223 &mut unifier,
224 &mut type_system.supply,
225 &type_system.env,
226 &type_system.adts,
227 &known,
228 expr,
229 )
230 .map_err(|err| with_span(expr.span(), err))?;
231 let subst = unifier.into_subst();
232 let preds = dedup_preds(preds.apply(&subst));
233 let t = t.apply(&subst);
234 type_system.check_predicate_kinds(&preds)?;
235 finalize_infer_for_public_api(preds, t)
236}
237
238fn infer_inner(
239 type_system: &mut TypeSystem,
240 expr: &Expr,
241) -> Result<(Vec<Predicate>, Type), TypeError> {
242 let known = KnownVariants::new();
243 let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
244 let (preds, t) = infer_expr_type(
245 &mut unifier,
246 &mut type_system.supply,
247 &type_system.env,
248 &type_system.adts,
249 &known,
250 expr,
251 )
252 .map_err(|err| with_span(expr.span(), err))?;
253 let subst = unifier.into_subst();
254 let mut preds = dedup_preds(preds.apply(&subst));
255 let mut t = t.apply(&subst);
256 let improve = improve_indexable(&preds)?;
257 if !subst_is_empty(&improve) {
258 preds = dedup_preds(preds.apply(&improve));
259 t = t.apply(&improve);
260 }
261 type_system.check_predicate_kinds(&preds)?;
262 finalize_infer_for_public_api(preds, t)
263}
264
265fn improve_indexable(preds: &[Predicate]) -> Result<Subst, TypeError> {
266 let mut subst = Subst::new_sync();
267 loop {
268 let mut changed = false;
269 for pred in preds {
270 let pred = pred.apply(&subst);
271 if pred.class.as_ref() != "Indexable" {
272 continue;
273 }
274 let TypeKind::Tuple(parts) = pred.typ.as_ref() else {
275 continue;
276 };
277 if parts.len() != 2 {
278 continue;
279 }
280 let container = parts[0].clone();
281 let elem = parts[1].clone();
282 let s = indexable_elem_subst(&container, &elem)?;
283 if !subst_is_empty(&s) {
284 subst = compose_subst(s, subst);
285 changed = true;
286 }
287 }
288 if !changed {
289 break;
290 }
291 }
292 Ok(subst)
293}
294
295fn indexable_elem_subst(container: &Type, elem: &Type) -> Result<Subst, TypeError> {
296 match container.as_ref() {
297 TypeKind::App(head, arg) => match head.as_ref() {
298 TypeKind::Con(tc)
299 if matches!(
300 tc.builtin_id,
301 Some(BuiltinTypeId::List | BuiltinTypeId::Array)
302 ) =>
303 {
304 unify(elem, arg)
305 }
306 _ => Ok(Subst::new_sync()),
307 },
308 TypeKind::Tuple(elems) => {
309 if elems.is_empty() {
310 return Ok(Subst::new_sync());
311 }
312 let mut subst = Subst::new_sync();
313 let mut cur = elems[0].clone();
314 for ty in elems.iter().skip(1) {
315 let s_next = unify(&cur.apply(&subst), &ty.apply(&subst))?;
316 subst = compose_subst(s_next, subst);
317 cur = cur.apply(&subst);
318 }
319 let elem = elem.apply(&subst);
320 let s_elem = unify(&elem, &cur.apply(&subst))?;
321 Ok(compose_subst(s_elem, subst))
322 }
323 _ => Ok(Subst::new_sync()),
324 }
325}
326
327type LambdaChain<'a> = (
328 Vec<(Symbol, Option<TypeExpr>)>,
329 Vec<TypeConstraint>,
330 &'a Expr,
331);
332
333fn collect_lambda_chain<'a>(expr: &'a Expr) -> LambdaChain<'a> {
334 let mut params = Vec::new();
335 let mut constraints = Vec::new();
336 let mut cur = expr;
337 let mut seen_constraints = false;
338 while let Expr::Lam(_, _scope, param, ann, lam_constraints, body) = cur {
339 if !lam_constraints.is_empty() {
340 if seen_constraints {
341 break;
342 }
343 constraints = lam_constraints.clone();
344 seen_constraints = true;
345 }
346 params.push((param.name.clone(), ann.clone()));
347 cur = body.as_ref();
348 }
349 (params, constraints, cur)
350}
351
352fn collect_app_chain(expr: &Expr) -> (&Expr, Vec<&Expr>) {
353 let mut args = Vec::new();
354 let mut cur = expr;
355 while let Expr::App(_, f, x) = cur {
356 args.push(x.as_ref());
357 cur = f.as_ref();
358 }
359 args.reverse();
360 (cur, args)
361}
362
363fn narrow_overload_candidates(candidates: &[Type], arg_ty: &Type) -> Vec<Type> {
364 let mut out = Vec::new();
365 for candidate in candidates {
366 let Some((params, ret)) = decompose_fun(candidate, 1) else {
367 continue;
368 };
369 let param = ¶ms[0];
370 if let Ok(s) = unify(param, arg_ty) {
371 out.push(ret.apply(&s));
372 }
373 }
374 out
375}
376
377fn unary_app_arg(typ: &Type, ctor_name: &str) -> Option<Type> {
378 let TypeKind::App(head, arg) = typ.as_ref() else {
379 return None;
380 };
381 let TypeKind::Con(tc) = head.as_ref() else {
382 return None;
383 };
384 (tc.name.as_ref() == ctor_name && tc.arity == 1).then(|| arg.clone())
385}
386
387fn infer_app_arg_type(
388 unifier: &mut Unifier<'_>,
389 supply: &mut TypeVarSupply,
390 env: &TypeEnv,
391 adts: &HashMap<Symbol, AdtDecl>,
392 known: &KnownVariants,
393 arg_hint: Option<Type>,
394 arg: &Expr,
395) -> Result<(Vec<Predicate>, Type), TypeError> {
396 match (arg_hint, arg) {
397 (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
398 infer_record_update_type_with_hint(
399 unifier,
400 supply,
401 env,
402 adts,
403 known,
404 base.as_ref(),
405 updates,
406 &arg_hint,
407 )
408 }
409 (Some(arg_hint), Expr::Dict(_, kvs))
410 if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
411 {
412 let TypeKind::Record(fields) = arg_hint.as_ref() else {
413 unreachable!("guarded by matches!")
414 };
415 let expected: HashMap<_, _> =
416 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
417 let mut seen = HashSet::new();
418 let mut preds = Vec::new();
419 for (k, v) in kvs {
420 let expected_ty = expected
421 .get(k)
422 .ok_or_else(|| TypeError::UnknownField {
423 field: k.clone(),
424 typ: Type::record(fields.clone()).to_string(),
425 })?
426 .clone();
427 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
428 unifier.unify(&t1, &expected_ty)?;
429 preds.extend(p1);
430 seen.insert(k.clone());
431 }
432 for key in expected.keys() {
433 if !seen.contains(key.as_ref()) {
434 return Err(TypeError::UnknownField {
435 field: key.clone(),
436 typ: Type::record(fields.clone()).to_string(),
437 });
438 }
439 }
440 let record_ty = Type::record(
441 fields
442 .iter()
443 .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
444 .collect(),
445 );
446 Ok((preds, record_ty))
447 }
448 _ => infer_expr_type(unifier, supply, env, adts, known, arg),
449 }
450}
451
452fn infer_app_arg_typed(
453 unifier: &mut Unifier<'_>,
454 supply: &mut TypeVarSupply,
455 env: &TypeEnv,
456 adts: &HashMap<Symbol, AdtDecl>,
457 known: &KnownVariants,
458 arg_hint: Option<Type>,
459 arg: &Expr,
460) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
461 match (arg_hint, arg) {
462 (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
463 infer_record_update_typed_with_hint(
464 unifier,
465 supply,
466 env,
467 adts,
468 known,
469 base.as_ref(),
470 updates,
471 &arg_hint,
472 )
473 }
474 (Some(arg_hint), Expr::Dict(_, kvs))
475 if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
476 {
477 let TypeKind::Record(fields) = arg_hint.as_ref() else {
478 unreachable!("guarded by matches!")
479 };
480 let mut preds = Vec::new();
481 let mut typed_kvs = BTreeMap::new();
482 let expected: HashMap<_, _> =
483 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
484 for (k, v) in kvs {
485 let expected_ty = expected
486 .get(k)
487 .ok_or_else(|| TypeError::UnknownField {
488 field: k.clone(),
489 typ: Type::record(fields.clone()).to_string(),
490 })?
491 .clone();
492 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
493 unifier.unify(&t1, &expected_ty)?;
494 preds.extend(p1);
495 typed_kvs.insert(k.clone(), typed_v);
496 }
497 for key in expected.keys() {
498 if !typed_kvs.contains_key(key.as_ref()) {
499 return Err(TypeError::UnknownField {
500 field: key.clone(),
501 typ: Type::record(fields.clone()).to_string(),
502 });
503 }
504 }
505 let record_ty = Type::record(
506 fields
507 .iter()
508 .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
509 .collect(),
510 );
511 let typed = TypedExpr::new(record_ty.clone(), TypedExprKind::Dict(typed_kvs));
512 Ok((preds, record_ty, typed))
513 }
514 _ => infer_expr(unifier, supply, env, adts, known, arg),
515 }
516}
517
518#[allow(clippy::too_many_arguments)]
519fn infer_record_update_type_with_hint(
520 unifier: &mut Unifier<'_>,
521 supply: &mut TypeVarSupply,
522 env: &TypeEnv,
523 adts: &HashMap<Symbol, AdtDecl>,
524 known: &KnownVariants,
525 base: &Expr,
526 updates: &BTreeMap<Symbol, Arc<Expr>>,
527 hint_ty: &Type,
528) -> Result<(Vec<Predicate>, Type), TypeError> {
529 let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
530 unifier.unify(&t_base, hint_ty)?;
531 let base_ty = unifier.apply_type(&t_base);
532 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
533 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
534 let (result_ty, fields) = resolve_record_update(
535 unifier,
536 supply,
537 adts,
538 &base_ty,
539 known_variant,
540 &update_fields,
541 )?;
542 let expected: HashMap<_, _> = fields.into_iter().collect();
543
544 let mut preds = p_base;
545 for (k, v) in updates {
546 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
547 field: k.clone(),
548 typ: result_ty.to_string(),
549 })?;
550 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
551 unifier.unify(&t1, expected_ty)?;
552 preds.extend(p1);
553 }
554 Ok((preds, result_ty))
555}
556
557#[allow(clippy::too_many_arguments)]
558fn infer_record_update_typed_with_hint(
559 unifier: &mut Unifier<'_>,
560 supply: &mut TypeVarSupply,
561 env: &TypeEnv,
562 adts: &HashMap<Symbol, AdtDecl>,
563 known: &KnownVariants,
564 base: &Expr,
565 updates: &BTreeMap<Symbol, Arc<Expr>>,
566 hint_ty: &Type,
567) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
568 let (p_base, t_base, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
569 unifier.unify(&t_base, hint_ty)?;
570 let base_ty = unifier.apply_type(&t_base);
571 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
572 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
573 let (result_ty, fields) = resolve_record_update(
574 unifier,
575 supply,
576 adts,
577 &base_ty,
578 known_variant,
579 &update_fields,
580 )?;
581 let expected: HashMap<_, _> = fields.into_iter().collect();
582
583 let mut preds = p_base;
584 let mut typed_updates = BTreeMap::new();
585 for (k, v) in updates {
586 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
587 field: k.clone(),
588 typ: result_ty.to_string(),
589 })?;
590 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
591 unifier.unify(&t1, expected_ty)?;
592 preds.extend(p1);
593 typed_updates.insert(k.clone(), typed_v);
594 }
595
596 let typed = TypedExpr::new(
597 result_ty.clone(),
598 TypedExprKind::RecordUpdate {
599 base: Box::new(typed_base),
600 updates: typed_updates,
601 },
602 );
603 Ok((preds, result_ty, typed))
604}
605
606fn infer_expr_type(
607 unifier: &mut Unifier<'_>,
608 supply: &mut TypeVarSupply,
609 env: &TypeEnv,
610 adts: &HashMap<Symbol, AdtDecl>,
611 known: &KnownVariants,
612 expr: &Expr,
613) -> Result<(Vec<Predicate>, Type), TypeError> {
614 let span = *expr.span();
615 let res = unifier.with_infer_depth(span, |unifier| {
616 infer_expr_type_inner(unifier, supply, env, adts, known, expr)
617 });
618 res.map_err(|err| with_span(&span, err))
619}
620
621fn infer_expr_type_inner(
622 unifier: &mut Unifier<'_>,
623 supply: &mut TypeVarSupply,
624 env: &TypeEnv,
625 adts: &HashMap<Symbol, AdtDecl>,
626 known: &KnownVariants,
627 expr: &Expr,
628) -> Result<(Vec<Predicate>, Type), TypeError> {
629 unifier.charge_infer_node()?;
630 match expr {
631 Expr::Bool(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Bool))),
632 Expr::Uint(_, _) => {
633 let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
634 Ok((vec![Predicate::new("Integral", lit_ty.clone())], lit_ty))
635 }
636 Expr::Int(_, _) => {
637 let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
638 Ok((
639 vec![
640 Predicate::new("Integral", lit_ty.clone()),
641 Predicate::new("AdditiveGroup", lit_ty.clone()),
642 ],
643 lit_ty,
644 ))
645 }
646 Expr::Float(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::F32))),
647 Expr::String(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::String))),
648 Expr::Uuid(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Uuid))),
649 Expr::DateTime(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::DateTime))),
650 Expr::Hole(_) => {
651 let t = Type::var(supply.fresh(Some(sym("hole"))));
652 Ok((vec![], t))
653 }
654 Expr::Var(var) => {
655 let schemes = env
656 .lookup(&var.name)
657 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
658 if schemes.len() == 1 {
659 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
660 let (preds, t) = instantiate(&scheme, supply);
661 Ok((preds, t))
662 } else {
663 for scheme in schemes {
664 if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
665 return Err(TypeError::AmbiguousOverload(var.name.clone()));
666 }
667 }
668 let t = Type::var(supply.fresh(Some(var.name.clone())));
669 Ok((vec![], t))
670 }
671 }
672 Expr::Lam(..) => {
673 let (params, constraints, body) = collect_lambda_chain(expr);
674 let mut ann_vars = HashMap::new();
675 let mut param_tys = Vec::with_capacity(params.len());
676 for (name, ann) in ¶ms {
677 let param_ty = match ann {
678 Some(ann) => type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?,
679 None => Type::var(supply.fresh(Some(name.clone()))),
680 };
681 param_tys.push((name.clone(), param_ty));
682 }
683
684 let mut env1 = env.clone();
685 let mut known_body = known.clone();
686 for (name, param_ty) in ¶m_tys {
687 env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
688 known_body.remove(name);
689 }
690
691 let (mut preds, body_ty) =
692 infer_expr_type(unifier, supply, &env1, adts, &known_body, body)?;
693 let constraint_preds =
694 predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
695 preds.extend(constraint_preds);
696
697 let mut fun_ty = unifier.apply_type(&body_ty);
698 for (_, param_ty) in param_tys.iter().rev() {
699 fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
700 }
701 Ok((preds, fun_ty))
702 }
703 Expr::App(..) => {
704 let (head, args) = collect_app_chain(expr);
705 let (mut preds, mut func_ty) =
706 infer_expr_type(unifier, supply, env, adts, known, head)?;
707 let mut overload_name = None;
708 let mut overload_candidates = if let Expr::Var(var) = head {
709 if let Some(schemes) = env.lookup(&var.name) {
710 if schemes.len() <= 1 {
711 None
712 } else {
713 let mut candidates = Vec::new();
714 for scheme in schemes {
715 if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
716 return Err(TypeError::AmbiguousOverload(var.name.clone()));
717 }
718 let scheme = apply_scheme_with_unifier(scheme, unifier);
719 let (p, typ) = instantiate(&scheme, supply);
720 if !p.is_empty() {
721 return Err(TypeError::AmbiguousOverload(var.name.clone()));
722 }
723 candidates.push(typ);
724 }
725 overload_name = Some(var.name.clone());
726 Some(candidates)
727 }
728 } else {
729 None
730 }
731 } else {
732 None
733 };
734 for arg in args {
735 let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
736 TypeKind::Fun(arg, _) => Some(arg.clone()),
737 _ => None,
738 };
739 let (p_arg, arg_ty) =
740 infer_app_arg_type(unifier, supply, env, adts, known, arg_hint, arg)?;
741 let arg_ty = unifier.apply_type(&arg_ty);
742 if let Some(candidates) = overload_candidates.take() {
743 let candidates = candidates
744 .into_iter()
745 .map(|t| unifier.apply_type(&t))
746 .collect::<Vec<_>>();
747 let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
748 if narrowed.is_empty()
749 && let Some(name) = &overload_name
750 {
751 return Err(TypeError::AmbiguousOverload(name.clone()));
752 }
753 overload_candidates = Some(narrowed);
754 }
755 let res_ty = match overload_candidates.as_ref() {
756 Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
757 _ => Type::var(supply.fresh(Some("r".into()))),
758 };
759 unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
760 preds.extend(p_arg);
761 func_ty = match overload_candidates.as_ref() {
762 Some(candidates) if candidates.len() == 1 => unifier.apply_type(&candidates[0]),
763 _ => unifier.apply_type(&res_ty),
764 };
765 }
766 Ok((preds, func_ty))
767 }
768 Expr::Project(_, base, field) => {
769 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, base)?;
770 let base_ty = unifier.apply_type(&t1);
771 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
772 let field_ty =
773 resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
774 Ok((p1, field_ty))
775 }
776 Expr::RecordUpdate(_, base, updates) => {
777 let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
778 let base_ty = unifier.apply_type(&t_base);
779 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
780 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
781 let (result_ty, fields) = resolve_record_update(
782 unifier,
783 supply,
784 adts,
785 &base_ty,
786 known_variant,
787 &update_fields,
788 )?;
789 let expected: HashMap<_, _> = fields.into_iter().collect();
790
791 let mut preds = p_base;
792 for (k, v) in updates {
793 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
794 field: k.clone(),
795 typ: result_ty.to_string(),
796 })?;
797 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
798 unifier.unify(&t1, expected_ty)?;
799 preds.extend(p1);
800 }
801 Ok((preds, result_ty))
802 }
803 Expr::Let(..) => {
804 let mut bindings = Vec::new();
805 let mut cur = expr;
806 while let Expr::Let(_, v, ann, d, b) = cur {
807 bindings.push((v.clone(), ann.clone(), d.clone()));
808 cur = b.as_ref();
809 }
810
811 let mut env_cur = env.clone();
812 let mut known_cur = known.clone();
813 for (v, ann, d) in bindings {
814 let (p1, t1) = if let Some(ref ann_expr) = ann {
815 let mut ann_vars = HashMap::new();
816 let ann_ty =
817 type_from_annotation_expr_vars(adts, ann_expr, &mut ann_vars, supply)?;
818 match d.as_ref() {
819 Expr::RecordUpdate(_, base, updates) => infer_record_update_type_with_hint(
820 unifier,
821 supply,
822 &env_cur,
823 adts,
824 &known_cur,
825 base.as_ref(),
826 updates,
827 &ann_ty,
828 )?,
829 _ => {
830 let (p1, t1) =
831 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?;
832 unifier.unify(&t1, &ann_ty)?;
833 (p1, t1)
834 }
835 }
836 } else {
837 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?
838 };
839 let def_ty = unifier.apply_type(&t1);
840 let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
841 monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
842 } else {
843 let scheme = generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
844 reject_ambiguous_scheme(&scheme)?;
845 scheme
846 };
847 env_cur.extend(v.name.clone(), scheme);
848 if let Some(known_variant) =
849 known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
850 {
851 known_cur.insert(
852 v.name.clone(),
853 KnownVariant {
854 adt: known_variant.adt,
855 variant: known_variant.variant,
856 },
857 );
858 } else {
859 known_cur.remove(&v.name);
860 }
861 }
862
863 let (p_body, t_body) =
864 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, cur)?;
865 Ok((p_body, t_body))
866 }
867 Expr::LetRec(_, bindings, body) => {
868 let mut env_seed = env.clone();
869 let mut known_seed = known.clone();
870 let mut binding_tys = HashMap::new();
871 for (var, _ann, _def) in bindings {
872 let tv = Type::var(supply.fresh(Some(var.name.clone())));
873 env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
874 known_seed.remove(&var.name);
875 binding_tys.insert(var.name.clone(), tv);
876 }
877
878 let mut inferred = Vec::with_capacity(bindings.len());
879 for (var, ann, def) in bindings {
880 let (preds, def_ty) =
881 infer_expr_type(unifier, supply, &env_seed, adts, &known_seed, def)?;
882 if let Some(ann) = ann {
883 let mut ann_vars = HashMap::new();
884 let ann_ty = type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
885 unifier.unify(&def_ty, &ann_ty)?;
886 }
887 let binding_ty = binding_tys
888 .get(&var.name)
889 .cloned()
890 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
891 unifier.unify(&binding_ty, &def_ty)?;
892 let resolved_ty = unifier.apply_type(&binding_ty);
893
894 if let Some(known_variant) =
895 known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
896 {
897 known_seed.insert(
898 var.name.clone(),
899 KnownVariant {
900 adt: known_variant.adt,
901 variant: known_variant.variant,
902 },
903 );
904 } else {
905 known_seed.remove(&var.name);
906 }
907 inferred.push((var.name.clone(), preds, resolved_ty));
908 }
909
910 let mut env_body = env.clone();
911 for (name, preds, def_ty) in inferred {
912 let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
913 reject_ambiguous_scheme(&scheme)?;
914 env_body.extend(name, scheme);
915 }
916
917 let (p_body, t_body) =
918 infer_expr_type(unifier, supply, &env_body, adts, &known_seed, body)?;
919 Ok((p_body, t_body))
920 }
921 Expr::Ite(_, cond, then_expr, else_expr) => {
922 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, cond)?;
923 unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
924 let (p2, t2) = infer_expr_type(unifier, supply, env, adts, known, then_expr)?;
925 let (p3, t3) = infer_expr_type(unifier, supply, env, adts, known, else_expr)?;
926 unifier.unify(&t2, &t3)?;
927 let out_ty = unifier.apply_type(&t2);
928 let mut preds = p1;
929 preds.extend(p2);
930 preds.extend(p3);
931 Ok((preds, out_ty))
932 }
933 Expr::Tuple(_, elems) => {
934 let mut preds = Vec::new();
935 let mut types = Vec::new();
936 for elem in elems {
937 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
938 preds.extend(p1);
939 types.push(unifier.apply_type(&t1));
940 }
941 let tuple_ty = Type::tuple(types);
942 Ok((preds, tuple_ty))
943 }
944 Expr::List(_, elems) => {
945 let elem_tv = Type::var(supply.fresh(Some("a".into())));
946 let mut preds = Vec::new();
947 for elem in elems {
948 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
949 unifier.unify(&t1, &elem_tv)?;
950 preds.extend(p1);
951 }
952 let list_ty = Type::app(
953 Type::builtin(BuiltinTypeId::List),
954 unifier.apply_type(&elem_tv),
955 );
956 Ok((preds, list_ty))
957 }
958 Expr::Dict(_, kvs) => {
959 let elem_tv = Type::var(supply.fresh(Some("v".into())));
960 let mut preds = Vec::new();
961 for v in kvs.values() {
962 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
963 unifier.unify(&t1, &elem_tv)?;
964 preds.extend(p1);
965 }
966 let dict_ty = Type::app(
967 Type::builtin(BuiltinTypeId::Dict),
968 unifier.apply_type(&elem_tv),
969 );
970 Ok((preds, dict_ty))
971 }
972 Expr::Match(_, scrutinee, arms) => {
973 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, scrutinee.as_ref())?;
974 let mut preds = p1;
975 let res_ty = Type::var(supply.fresh(Some("match".into())));
976 let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
977
978 for (pat, expr) in arms {
979 let scrutinee_ty = unifier.apply_type(&t1);
980 let (p_pat, binds) = infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
981 preds.extend(p_pat);
982
983 let mut env_arm = env.clone();
984 for (name, ty) in binds {
985 env_arm.extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
986 }
987 let mut known_arm = known.clone();
988 if let Expr::Var(var) = scrutinee.as_ref() {
989 match pat {
990 Pattern::Named(_, name, _) => {
991 let name_sym = name.to_dotted_symbol();
992 if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
993 known_arm.insert(
994 var.name.clone(),
995 KnownVariant {
996 adt: adt.name.clone(),
997 variant: name_sym,
998 },
999 );
1000 } else {
1001 known_arm.remove(&var.name);
1002 }
1003 }
1004 _ => {
1005 known_arm.remove(&var.name);
1006 }
1007 }
1008 }
1009 let (p_expr, t_expr) =
1010 infer_expr_type(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1011 unifier.unify(&res_ty, &t_expr)?;
1012 preds.extend(p_expr);
1013 }
1014
1015 let scrutinee_ty = unifier.apply_type(&t1);
1016 check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1017 let out_ty = unifier.apply_type(&res_ty);
1018 Ok((preds, out_ty))
1019 }
1020 Expr::Ann(_, expr, ann) => {
1021 let ann_ty = type_from_annotation_expr(adts, ann)?;
1022 match expr.as_ref() {
1023 Expr::RecordUpdate(_, base, updates) => {
1024 let (preds, out_ty) = infer_record_update_type_with_hint(
1025 unifier,
1026 supply,
1027 env,
1028 adts,
1029 known,
1030 base.as_ref(),
1031 updates,
1032 &ann_ty,
1033 )?;
1034 Ok((preds, out_ty))
1035 }
1036 _ => {
1037 let (preds, expr_ty) =
1038 infer_expr_type(unifier, supply, env, adts, known, expr)?;
1039 unifier.unify(&expr_ty, &ann_ty)?;
1040 let out_ty = unifier.apply_type(&ann_ty);
1041 Ok((preds, out_ty))
1042 }
1043 }
1044 }
1045 }
1046}
1047
1048fn infer_expr(
1049 unifier: &mut Unifier<'_>,
1050 supply: &mut TypeVarSupply,
1051 env: &TypeEnv,
1052 adts: &HashMap<Symbol, AdtDecl>,
1053 known: &KnownVariants,
1054 expr: &Expr,
1055) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
1056 let span = *expr.span();
1057 let res = unifier.with_infer_depth(span, |unifier| {
1058 (|| {
1059 unifier.charge_infer_node()?;
1060 match expr {
1061 Expr::Bool(_, v) => {
1062 let t = Type::builtin(BuiltinTypeId::Bool);
1063 Ok((
1064 vec![],
1065 t.clone(),
1066 TypedExpr::new(t, TypedExprKind::Bool(*v)),
1067 ))
1068 }
1069 Expr::Uint(_, v) => {
1070 let t = Type::var(supply.fresh(Some(sym("n"))));
1071 Ok((
1072 vec![Predicate::new("Integral", t.clone())],
1073 t.clone(),
1074 TypedExpr::new(t, TypedExprKind::Uint(*v)),
1075 ))
1076 }
1077 Expr::Int(_, v) => {
1078 let t = Type::var(supply.fresh(Some(sym("n"))));
1079 Ok((
1080 vec![
1081 Predicate::new("Integral", t.clone()),
1082 Predicate::new("AdditiveGroup", t.clone()),
1083 ],
1084 t.clone(),
1085 TypedExpr::new(t, TypedExprKind::Int(*v)),
1086 ))
1087 }
1088 Expr::Float(_, v) => {
1089 let t = Type::builtin(BuiltinTypeId::F32);
1090 Ok((
1091 vec![],
1092 t.clone(),
1093 TypedExpr::new(t, TypedExprKind::Float(*v)),
1094 ))
1095 }
1096 Expr::String(_, v) => {
1097 let t = Type::builtin(BuiltinTypeId::String);
1098 Ok((
1099 vec![],
1100 t.clone(),
1101 TypedExpr::new(t, TypedExprKind::String(v.clone())),
1102 ))
1103 }
1104 Expr::Uuid(_, v) => {
1105 let t = Type::builtin(BuiltinTypeId::Uuid);
1106 Ok((
1107 vec![],
1108 t.clone(),
1109 TypedExpr::new(t, TypedExprKind::Uuid(*v)),
1110 ))
1111 }
1112 Expr::DateTime(_, v) => {
1113 let t = Type::builtin(BuiltinTypeId::DateTime);
1114 Ok((
1115 vec![],
1116 t.clone(),
1117 TypedExpr::new(t, TypedExprKind::DateTime(*v)),
1118 ))
1119 }
1120 Expr::Hole(_) => {
1121 let t = Type::var(supply.fresh(Some(sym("hole"))));
1122 Ok((vec![], t.clone(), TypedExpr::new(t, TypedExprKind::Hole)))
1123 }
1124 Expr::Var(var) => {
1125 let schemes = env
1126 .lookup(&var.name)
1127 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1128 if schemes.len() == 1 {
1129 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
1130 let (preds, t) = instantiate(&scheme, supply);
1131 let typed = TypedExpr::new(
1132 t.clone(),
1133 TypedExprKind::Var {
1134 name: var.name.clone(),
1135 overloads: vec![],
1136 },
1137 );
1138 Ok((preds, t, typed))
1139 } else {
1140 let mut overloads = Vec::new();
1141 for scheme in schemes {
1142 if !scheme.preds.is_empty() {
1143 return Err(TypeError::AmbiguousOverload(var.name.clone()));
1144 }
1145
1146 let scheme = apply_scheme_with_unifier(scheme, unifier);
1147 let (preds, typ) = instantiate(&scheme, supply);
1148 if !preds.is_empty() {
1149 return Err(TypeError::AmbiguousOverload(var.name.clone()));
1150 }
1151 overloads.push(typ);
1152 }
1153 let t = Type::var(supply.fresh(Some(var.name.clone())));
1154 let typed = TypedExpr::new(
1155 t.clone(),
1156 TypedExprKind::Var {
1157 name: var.name.clone(),
1158 overloads,
1159 },
1160 );
1161 Ok((vec![], t, typed))
1162 }
1163 }
1164 Expr::Lam(..) => {
1165 let (params, constraints, body) = collect_lambda_chain(expr);
1166 let mut ann_vars = HashMap::new();
1167 let mut param_tys = Vec::with_capacity(params.len());
1168 for (name, ann) in ¶ms {
1169 let param_ty = match ann {
1170 Some(ann) => {
1171 type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?
1172 }
1173 None => Type::var(supply.fresh(Some(name.clone()))),
1174 };
1175 param_tys.push((name.clone(), param_ty));
1176 }
1177
1178 let mut env1 = env.clone();
1179 let mut known_body = known.clone();
1180 for (name, param_ty) in ¶m_tys {
1181 env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
1182 known_body.remove(name);
1183 }
1184
1185 let (mut preds, body_ty, typed_body) =
1186 infer_expr(unifier, supply, &env1, adts, &known_body, body)?;
1187 let constraint_preds =
1188 predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
1189 preds.extend(constraint_preds);
1190
1191 let mut typed = typed_body;
1192 let mut fun_ty = unifier.apply_type(&body_ty);
1193 for (name, param_ty) in param_tys.iter().rev() {
1194 fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
1195 typed = TypedExpr::new(
1196 fun_ty.clone(),
1197 TypedExprKind::Lam {
1198 param: name.clone(),
1199 body: Box::new(typed),
1200 },
1201 );
1202 }
1203
1204 Ok((preds, fun_ty, typed))
1205 }
1206 Expr::App(..) => {
1207 let (head, args) = collect_app_chain(expr);
1208 let (mut preds, mut func_ty, mut typed) =
1209 infer_expr(unifier, supply, env, adts, known, head)?;
1210 let mut overload_name = None;
1211 let mut overload_candidates = match &typed.kind {
1212 TypedExprKind::Var { name, overloads } if !overloads.is_empty() => {
1213 overload_name = Some(name.clone());
1214 Some(overloads.clone())
1215 }
1216 _ => None,
1217 };
1218 for arg in args {
1219 let expected_arg = match unifier.apply_type(&func_ty).as_ref() {
1220 TypeKind::Fun(arg, _) => Some(arg.clone()),
1221 _ => None,
1222 };
1223 let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
1224 TypeKind::Fun(arg, _) => Some(arg.clone()),
1225 _ => None,
1226 };
1227 let (p_arg, arg_ty, typed_arg) =
1228 infer_app_arg_typed(unifier, supply, env, adts, known, arg_hint, arg)?;
1229 let mut arg_ty = unifier.apply_type(&arg_ty);
1230 let mut typed_arg = typed_arg;
1231
1232 if let Some(expected_arg) = expected_arg {
1233 let expected_arg = unifier.apply_type(&expected_arg);
1234 if let (Some(expected_elem), Some(arg_elem)) = (
1235 unary_app_arg(&expected_arg, "Array"),
1236 unary_app_arg(&arg_ty, "List"),
1237 ) {
1238 unifier.unify(&expected_elem, &arg_elem)?;
1239 let elem_ty = unifier.apply_type(&expected_elem);
1240 let list_ty = Type::list(elem_ty.clone());
1241 let array_ty = Type::array(elem_ty);
1242 let coercion_ty = Type::fun(list_ty, array_ty.clone());
1243 let coercion_fn = TypedExpr::new(
1244 coercion_ty,
1245 TypedExprKind::Var {
1246 name: sym("prim_array_from_list"),
1247 overloads: vec![],
1248 },
1249 );
1250 typed_arg = TypedExpr::new(
1251 array_ty.clone(),
1252 TypedExprKind::App(Box::new(coercion_fn), Box::new(typed_arg)),
1253 );
1254 arg_ty = array_ty;
1255 }
1256 }
1257 if let Some(candidates) = overload_candidates.take() {
1258 let candidates = candidates
1259 .into_iter()
1260 .map(|t| unifier.apply_type(&t))
1261 .collect::<Vec<_>>();
1262 let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
1263 if narrowed.is_empty()
1264 && let Some(name) = &overload_name
1265 {
1266 return Err(TypeError::AmbiguousOverload(name.clone()));
1267 }
1268 overload_candidates = Some(narrowed);
1269 }
1270 let res_ty = match overload_candidates.as_ref() {
1271 Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
1272 _ => Type::var(supply.fresh(Some("r".into()))),
1273 };
1274 unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
1275 let result_ty = match overload_candidates.as_ref() {
1276 Some(candidates) if candidates.len() == 1 => {
1277 unifier.apply_type(&candidates[0])
1278 }
1279 _ => unifier.apply_type(&res_ty),
1280 };
1281 preds.extend(p_arg);
1282 typed = TypedExpr::new(
1283 result_ty.clone(),
1284 TypedExprKind::App(Box::new(typed), Box::new(typed_arg)),
1285 );
1286 func_ty = result_ty;
1287 }
1288 Ok((preds, func_ty, typed))
1289 }
1290 Expr::Project(_, base, field) => {
1291 let (p1, t1, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
1292 let base_ty = unifier.apply_type(&t1);
1293 let known_variant =
1294 known_variant_from_expr_with_known(base, &base_ty, adts, known);
1295 let field_ty =
1296 resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
1297 let typed = TypedExpr::new(
1298 field_ty.clone(),
1299 TypedExprKind::Project {
1300 expr: Box::new(typed_base),
1301 field: field.clone(),
1302 },
1303 );
1304 Ok((p1, field_ty, typed))
1305 }
1306 Expr::RecordUpdate(_, base, updates) => {
1307 let (p_base, t_base, typed_base) =
1308 infer_expr(unifier, supply, env, adts, known, base)?;
1309 let base_ty = unifier.apply_type(&t_base);
1310 let known_variant =
1311 known_variant_from_expr_with_known(base, &base_ty, adts, known);
1312 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
1313 let (result_ty, fields) = resolve_record_update(
1314 unifier,
1315 supply,
1316 adts,
1317 &base_ty,
1318 known_variant,
1319 &update_fields,
1320 )?;
1321 let expected: HashMap<_, _> = fields.into_iter().collect();
1322
1323 let mut preds = p_base;
1324 let mut typed_updates = BTreeMap::new();
1325 for (k, v) in updates {
1326 let expected_ty =
1327 expected.get(k).ok_or_else(|| TypeError::UnknownField {
1328 field: k.clone(),
1329 typ: result_ty.to_string(),
1330 })?;
1331 let (p1, t1, typed_v) =
1332 infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
1333 unifier.unify(&t1, expected_ty)?;
1334 preds.extend(p1);
1335 typed_updates.insert(k.clone(), typed_v);
1336 }
1337 let typed = TypedExpr::new(
1338 result_ty.clone(),
1339 TypedExprKind::RecordUpdate {
1340 base: Box::new(typed_base),
1341 updates: typed_updates,
1342 },
1343 );
1344 Ok((preds, result_ty, typed))
1345 }
1346 Expr::Let(..) => {
1347 let mut bindings = Vec::new();
1348 let mut cur = expr;
1349 while let Expr::Let(_, v, ann, d, b) = cur {
1350 bindings.push((v.clone(), ann.clone(), d.clone()));
1351 cur = b.as_ref();
1352 }
1353
1354 let mut env_cur = env.clone();
1355 let mut known_cur = known.clone();
1356 let mut typed_defs = Vec::new();
1357 for (v, ann, d) in bindings {
1358 let (p1, t1, typed_def) = if let Some(ref ann_expr) = ann {
1359 let mut ann_vars = HashMap::new();
1360 let ann_ty = type_from_annotation_expr_vars(
1361 adts,
1362 ann_expr,
1363 &mut ann_vars,
1364 supply,
1365 )?;
1366 match d.as_ref() {
1367 Expr::RecordUpdate(_, base, updates) => {
1368 infer_record_update_typed_with_hint(
1369 unifier,
1370 supply,
1371 &env_cur,
1372 adts,
1373 &known_cur,
1374 base.as_ref(),
1375 updates,
1376 &ann_ty,
1377 )?
1378 }
1379 _ => {
1380 let (p1, t1, typed_def) = infer_expr(
1381 unifier, supply, &env_cur, adts, &known_cur, &d,
1382 )?;
1383 unifier.unify(&t1, &ann_ty)?;
1384 (p1, t1, typed_def)
1385 }
1386 }
1387 } else {
1388 infer_expr(unifier, supply, &env_cur, adts, &known_cur, &d)?
1389 };
1390 let def_ty = unifier.apply_type(&t1);
1391 let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
1392 monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
1393 } else {
1394 let scheme =
1395 generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
1396 reject_ambiguous_scheme(&scheme)?;
1397 scheme
1398 };
1399 env_cur.extend(v.name.clone(), scheme);
1400 if let Some(known_variant) =
1401 known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
1402 {
1403 known_cur.insert(
1404 v.name.clone(),
1405 KnownVariant {
1406 adt: known_variant.adt,
1407 variant: known_variant.variant,
1408 },
1409 );
1410 } else {
1411 known_cur.remove(&v.name);
1412 }
1413 typed_defs.push((v.name.clone(), typed_def));
1414 }
1415
1416 let (p_body, t_body, typed_body) =
1417 infer_expr(unifier, supply, &env_cur, adts, &known_cur, cur)?;
1418
1419 let mut typed = typed_body;
1420 for (name, def) in typed_defs.into_iter().rev() {
1421 typed = TypedExpr::new(
1422 t_body.clone(),
1423 TypedExprKind::Let {
1424 name,
1425 def: Box::new(def),
1426 body: Box::new(typed),
1427 },
1428 );
1429 }
1430 Ok((p_body, t_body, typed))
1431 }
1432 Expr::LetRec(_, bindings, body) => {
1433 let mut env_seed = env.clone();
1434 let mut known_seed = known.clone();
1435 let mut binding_tys = HashMap::new();
1436 for (var, _ann, _def) in bindings {
1437 let tv = Type::var(supply.fresh(Some(var.name.clone())));
1438 env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
1439 known_seed.remove(&var.name);
1440 binding_tys.insert(var.name.clone(), tv);
1441 }
1442
1443 let mut inferred_defs = Vec::with_capacity(bindings.len());
1444 for (var, ann, def) in bindings {
1445 let (preds, def_ty, typed_def) =
1446 infer_expr(unifier, supply, &env_seed, adts, &known_seed, def)?;
1447 if let Some(ann) = ann {
1448 let mut ann_vars = HashMap::new();
1449 let ann_ty =
1450 type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
1451 unifier.unify(&def_ty, &ann_ty)?;
1452 }
1453 let binding_ty = binding_tys
1454 .get(&var.name)
1455 .cloned()
1456 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1457 unifier.unify(&binding_ty, &def_ty)?;
1458 let resolved_ty = unifier.apply_type(&binding_ty);
1459
1460 if let Some(known_variant) =
1461 known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
1462 {
1463 known_seed.insert(
1464 var.name.clone(),
1465 KnownVariant {
1466 adt: known_variant.adt,
1467 variant: known_variant.variant,
1468 },
1469 );
1470 } else {
1471 known_seed.remove(&var.name);
1472 }
1473 inferred_defs.push((var.name.clone(), preds, resolved_ty, typed_def));
1474 }
1475
1476 let mut env_body = env.clone();
1477 let mut typed_bindings = Vec::with_capacity(inferred_defs.len());
1478 for (name, preds, def_ty, typed_def) in inferred_defs {
1479 let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
1480 reject_ambiguous_scheme(&scheme)?;
1481 env_body.extend(name.clone(), scheme);
1482 typed_bindings.push((name, typed_def));
1483 }
1484
1485 let (p_body, t_body, typed_body) =
1486 infer_expr(unifier, supply, &env_body, adts, &known_seed, body)?;
1487 let typed = TypedExpr::new(
1488 t_body.clone(),
1489 TypedExprKind::LetRec {
1490 bindings: typed_bindings,
1491 body: Box::new(typed_body),
1492 },
1493 );
1494 Ok((p_body, t_body, typed))
1495 }
1496 Expr::Ite(_, cond, then_expr, else_expr) => {
1497 let (p1, t1, typed_cond) = infer_expr(unifier, supply, env, adts, known, cond)?;
1498 unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
1499 let (p2, t2, typed_then) =
1500 infer_expr(unifier, supply, env, adts, known, then_expr)?;
1501 let (p3, t3, typed_else) =
1502 infer_expr(unifier, supply, env, adts, known, else_expr)?;
1503 unifier.unify(&t2, &t3)?;
1504 let out_ty = unifier.apply_type(&t2);
1505 let mut preds = p1;
1506 preds.extend(p2);
1507 preds.extend(p3);
1508 let typed = TypedExpr::new(
1509 out_ty.clone(),
1510 TypedExprKind::Ite {
1511 cond: Box::new(typed_cond),
1512 then_expr: Box::new(typed_then),
1513 else_expr: Box::new(typed_else),
1514 },
1515 );
1516 Ok((preds, out_ty, typed))
1517 }
1518 Expr::Tuple(_, elems) => {
1519 let mut preds = Vec::new();
1520 let mut types = Vec::new();
1521 let mut typed_elems = Vec::new();
1522 for elem in elems {
1523 let (p1, t1, typed_elem) =
1524 infer_expr(unifier, supply, env, adts, known, elem)?;
1525 preds.extend(p1);
1526 types.push(unifier.apply_type(&t1));
1527 typed_elems.push(typed_elem);
1528 }
1529 let tuple_ty = Type::tuple(types);
1530 let typed = TypedExpr::new(tuple_ty.clone(), TypedExprKind::Tuple(typed_elems));
1531 Ok((preds, tuple_ty, typed))
1532 }
1533 Expr::List(_, elems) => {
1534 let elem_tv = Type::var(supply.fresh(Some("a".into())));
1535 let mut preds = Vec::new();
1536 let mut typed_elems = Vec::new();
1537 for elem in elems {
1538 let (p1, t1, typed_elem) =
1539 infer_expr(unifier, supply, env, adts, known, elem)?;
1540 unifier.unify(&t1, &elem_tv)?;
1541 preds.extend(p1);
1542 typed_elems.push(typed_elem);
1543 }
1544 let list_ty = Type::app(
1545 Type::builtin(BuiltinTypeId::List),
1546 unifier.apply_type(&elem_tv),
1547 );
1548 let typed = TypedExpr::new(list_ty.clone(), TypedExprKind::List(typed_elems));
1549 Ok((preds, list_ty, typed))
1550 }
1551 Expr::Dict(_, kvs) => {
1552 let elem_tv = Type::var(supply.fresh(Some("v".into())));
1553 let mut preds = Vec::new();
1554 let mut typed_kvs = BTreeMap::new();
1555 for (k, v) in kvs {
1556 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v)?;
1557 unifier.unify(&t1, &elem_tv)?;
1558 preds.extend(p1);
1559 typed_kvs.insert(k.clone(), typed_v);
1560 }
1561 let dict_ty = Type::app(
1562 Type::builtin(BuiltinTypeId::Dict),
1563 unifier.apply_type(&elem_tv),
1564 );
1565 let typed = TypedExpr::new(dict_ty.clone(), TypedExprKind::Dict(typed_kvs));
1566 Ok((preds, dict_ty, typed))
1567 }
1568 Expr::Match(_, scrutinee, arms) => {
1569 let (p1, t1, typed_scrutinee) =
1570 infer_expr(unifier, supply, env, adts, known, scrutinee)?;
1571 let mut preds = p1;
1572 let mut typed_arms = Vec::new();
1573 let res_ty = Type::var(supply.fresh(Some("match".into())));
1574 let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
1575
1576 for (pat, expr) in arms {
1577 let scrutinee_ty = unifier.apply_type(&t1);
1578 let (p_pat, binds) =
1579 infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
1580 preds.extend(p_pat);
1581
1582 let mut env_arm = env.clone();
1583 for (name, ty) in binds {
1584 env_arm
1585 .extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
1586 }
1587 let mut known_arm = known.clone();
1588 if let Expr::Var(var) = scrutinee.as_ref() {
1589 match pat {
1590 Pattern::Named(_, name, _) => {
1591 let name_sym = name.to_dotted_symbol();
1592 if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
1593 known_arm.insert(
1594 var.name.clone(),
1595 KnownVariant {
1596 adt: adt.name.clone(),
1597 variant: name_sym,
1598 },
1599 );
1600 } else {
1601 known_arm.remove(&var.name);
1602 }
1603 }
1604 _ => {
1605 known_arm.remove(&var.name);
1606 }
1607 }
1608 }
1609 let (p_expr, t_expr, typed_expr) =
1610 infer_expr(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1611 unifier.unify(&res_ty, &t_expr)?;
1612 preds.extend(p_expr);
1613 typed_arms.push((pat.clone(), typed_expr));
1614 }
1615
1616 let scrutinee_ty = unifier.apply_type(&t1);
1617 check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1618 let out_ty = unifier.apply_type(&res_ty);
1619 let typed = TypedExpr::new(
1620 out_ty.clone(),
1621 TypedExprKind::Match {
1622 scrutinee: Box::new(typed_scrutinee),
1623 arms: typed_arms,
1624 },
1625 );
1626 Ok((preds, out_ty, typed))
1627 }
1628 Expr::Ann(_, expr, ann) => {
1629 let ann_ty = type_from_annotation_expr(adts, ann)?;
1630 match expr.as_ref() {
1631 Expr::RecordUpdate(_, base, updates) => {
1632 infer_record_update_typed_with_hint(
1633 unifier,
1634 supply,
1635 env,
1636 adts,
1637 known,
1638 base.as_ref(),
1639 updates,
1640 &ann_ty,
1641 )
1642 }
1643 _ => {
1644 let (preds, expr_ty, typed_expr) =
1645 infer_expr(unifier, supply, env, adts, known, expr)?;
1646 unifier.unify(&expr_ty, &ann_ty)?;
1647 let out_ty = unifier.apply_type(&ann_ty);
1648 Ok((preds, out_ty, typed_expr))
1649 }
1650 }
1651 }
1652 }
1653 })()
1654 });
1655 res.map_err(|err| with_span(&span, err))
1656}
1657
1658fn ctor_lookup<'a>(
1659 adts: &'a HashMap<Symbol, AdtDecl>,
1660 name: &Symbol,
1661) -> Option<(&'a AdtDecl, &'a AdtVariant)> {
1662 let mut found = None;
1663 for adt in adts.values() {
1664 if let Some(variant) = adt.variants.iter().find(|v| &v.name == name) {
1665 if found.is_some() {
1666 return None;
1667 }
1668 found = Some((adt, variant));
1669 }
1670 }
1671 found
1672}
1673
1674fn record_fields(variant: &AdtVariant) -> Option<&[(Symbol, Type)]> {
1675 if variant.args.len() != 1 {
1676 return None;
1677 }
1678 match variant.args[0].as_ref() {
1679 TypeKind::Record(fields) => Some(fields),
1680 _ => None,
1681 }
1682}
1683
1684fn instantiate_variant_fields(
1685 adt: &AdtDecl,
1686 variant: &AdtVariant,
1687 supply: &mut TypeVarSupply,
1688) -> Option<(Type, Vec<(Symbol, Type)>)> {
1689 let fields = record_fields(variant)?;
1690 let mut subst = Subst::new_sync();
1691 for param in &adt.params {
1692 let fresh = Type::var(supply.fresh(param.var.name.clone()));
1693 subst = subst.insert(param.var.id, fresh);
1694 }
1695 let result_ty = adt.result_type().apply(&subst);
1696 let fields = fields
1697 .iter()
1698 .map(|(name, ty)| (name.clone(), ty.apply(&subst)))
1699 .collect();
1700 Some((result_ty, fields))
1701}
1702
1703fn known_variant_from_expr(
1704 expr: &Expr,
1705 expr_ty: &Type,
1706 adts: &HashMap<Symbol, AdtDecl>,
1707) -> Option<KnownVariant> {
1708 let mut expr = expr;
1709 while let Expr::Ann(_, inner, _) = expr {
1710 expr = inner.as_ref();
1711 }
1712 if matches!(expr_ty.as_ref(), TypeKind::Fun(..)) {
1713 return None;
1714 }
1715 let ctor = match expr {
1716 Expr::App(_, f, _) => match f.as_ref() {
1717 Expr::Var(var) => var.name.clone(),
1718 _ => return None,
1719 },
1720 _ => return None,
1721 };
1722 let (adt, variant) = ctor_lookup(adts, &ctor)?;
1723 record_fields(variant)?;
1724 Some(KnownVariant {
1725 adt: adt.name.clone(),
1726 variant: variant.name.clone(),
1727 })
1728}
1729
1730fn known_variant_from_expr_with_known(
1731 expr: &Expr,
1732 expr_ty: &Type,
1733 adts: &HashMap<Symbol, AdtDecl>,
1734 known: &KnownVariants,
1735) -> Option<KnownVariant> {
1736 let mut expr = expr;
1737 while let Expr::Ann(_, inner, _) = expr {
1738 expr = inner.as_ref();
1739 }
1740 match expr {
1741 Expr::Var(var) => known.get(&var.name).cloned(),
1742 Expr::RecordUpdate(_, base, _) => {
1743 known_variant_from_expr_with_known(base.as_ref(), expr_ty, adts, known)
1744 }
1745 _ => known_variant_from_expr(expr, expr_ty, adts),
1746 }
1747}
1748
1749fn select_record_variant<'a, F>(
1750 adts: &'a HashMap<Symbol, AdtDecl>,
1751 base_ty: &Type,
1752 known_variant: Option<KnownVariant>,
1753 field_for_errors: &Symbol,
1754 matches_fields: F,
1755) -> Result<(&'a AdtDecl, &'a AdtVariant), TypeError>
1756where
1757 F: Fn(&[(Symbol, Type)]) -> bool,
1758{
1759 if let Some(info) = known_variant {
1760 let adt = adts
1761 .get(&info.adt)
1762 .ok_or_else(|| TypeError::UnknownTypeName(info.adt.clone()))?;
1763 let variant = adt
1764 .variants
1765 .iter()
1766 .find(|v| v.name == info.variant)
1767 .ok_or_else(|| TypeError::UnknownField {
1768 field: field_for_errors.clone(),
1769 typ: base_ty.to_string(),
1770 })?;
1771 return Ok((adt, variant));
1772 }
1773
1774 if let Some(adt_name) = type_head_name(base_ty) {
1775 let adt = adts.get(adt_name).ok_or_else(|| TypeError::UnknownField {
1776 field: field_for_errors.clone(),
1777 typ: base_ty.to_string(),
1778 })?;
1779 if adt.variants.len() == 1 {
1780 return Ok((adt, &adt.variants[0]));
1781 }
1782 return Err(TypeError::FieldNotKnown {
1783 field: field_for_errors.clone(),
1784 typ: base_ty.to_string(),
1785 });
1786 }
1787
1788 if matches!(base_ty.as_ref(), TypeKind::Var(_)) {
1789 let mut candidates = Vec::new();
1790 for adt in adts.values() {
1791 if adt.variants.len() != 1 {
1792 continue;
1793 }
1794 let variant = &adt.variants[0];
1795 let Some(fields) = record_fields(variant) else {
1796 continue;
1797 };
1798 if matches_fields(fields) {
1799 candidates.push((adt, variant));
1800 }
1801 }
1802 if candidates.len() == 1 {
1803 return Ok(candidates.remove(0));
1804 }
1805 if candidates.is_empty() {
1806 return Err(TypeError::UnknownField {
1807 field: field_for_errors.clone(),
1808 typ: base_ty.to_string(),
1809 });
1810 }
1811 return Err(TypeError::FieldNotKnown {
1812 field: field_for_errors.clone(),
1813 typ: base_ty.to_string(),
1814 });
1815 }
1816
1817 Err(TypeError::UnknownField {
1818 field: field_for_errors.clone(),
1819 typ: base_ty.to_string(),
1820 })
1821}
1822
1823fn resolve_record_update(
1824 unifier: &mut Unifier<'_>,
1825 supply: &mut TypeVarSupply,
1826 adts: &HashMap<Symbol, AdtDecl>,
1827 base_ty: &Type,
1828 known_variant: Option<KnownVariant>,
1829 update_fields: &[Symbol],
1830) -> Result<(Type, Vec<(Symbol, Type)>), TypeError> {
1831 if let TypeKind::Record(fields) = base_ty.as_ref() {
1832 return Ok((base_ty.clone(), fields.clone()));
1833 }
1834
1835 let field_for_errors = update_fields.first().cloned().unwrap_or_else(|| sym("_"));
1836
1837 let (adt, variant) =
1838 select_record_variant(adts, base_ty, known_variant, &field_for_errors, |fields| {
1839 update_fields
1840 .iter()
1841 .all(|field| fields.iter().any(|(name, _)| name == field))
1842 })?;
1843
1844 let (result_ty, fields) =
1845 instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1846 TypeError::UnknownField {
1847 field: field_for_errors.clone(),
1848 typ: base_ty.to_string(),
1849 }
1850 })?;
1851
1852 for field in update_fields {
1853 if fields.iter().all(|(name, _)| name != field) {
1854 return Err(TypeError::UnknownField {
1855 field: field.clone(),
1856 typ: base_ty.to_string(),
1857 });
1858 }
1859 }
1860
1861 unifier.unify(base_ty, &result_ty)?;
1862 let result_ty = unifier.apply_type(&result_ty);
1863 let fields = fields
1864 .into_iter()
1865 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
1866 .collect();
1867 Ok((result_ty, fields))
1868}
1869
1870fn resolve_projection(
1871 unifier: &mut Unifier<'_>,
1872 supply: &mut TypeVarSupply,
1873 adts: &HashMap<Symbol, AdtDecl>,
1874 base_ty: &Type,
1875 known_variant: Option<KnownVariant>,
1876 field: &Symbol,
1877) -> Result<Type, TypeError> {
1878 if let Ok(index) = field.as_ref().parse::<usize>() {
1879 let elem_ty = match base_ty.as_ref() {
1880 TypeKind::Tuple(elems) => {
1881 elems
1882 .get(index)
1883 .cloned()
1884 .ok_or_else(|| TypeError::UnknownField {
1885 field: field.clone(),
1886 typ: base_ty.to_string(),
1887 })?
1888 }
1889 TypeKind::Var(_) => {
1890 let mut elems = Vec::with_capacity(index + 1);
1891 for _ in 0..=index {
1892 elems.push(Type::var(supply.fresh(Some(sym("t")))));
1893 }
1894 let tuple_ty = Type::tuple(elems.clone());
1895 unifier.unify(base_ty, &tuple_ty)?;
1896 elems[index].clone()
1897 }
1898 _ => {
1899 return Err(TypeError::UnknownField {
1900 field: field.clone(),
1901 typ: base_ty.to_string(),
1902 });
1903 }
1904 };
1905 return Ok(unifier.apply_type(&elem_ty));
1906 }
1907
1908 let (adt, variant) = select_record_variant(adts, base_ty, known_variant, field, |fields| {
1909 fields.iter().any(|(name, _)| name == field)
1910 })?;
1911
1912 let (result_ty, fields) =
1913 instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1914 TypeError::UnknownField {
1915 field: field.clone(),
1916 typ: base_ty.to_string(),
1917 }
1918 })?;
1919 let field_ty = fields
1920 .iter()
1921 .find(|(name, _)| name == field)
1922 .map(|(_, ty)| ty.clone())
1923 .ok_or_else(|| TypeError::UnknownField {
1924 field: field.clone(),
1925 typ: base_ty.to_string(),
1926 })?;
1927 unifier.unify(base_ty, &result_ty)?;
1928 Ok(unifier.apply_type(&field_ty))
1929}
1930
1931fn decompose_fun(typ: &Type, arity: usize) -> Option<(Vec<Type>, Type)> {
1932 let mut args = Vec::with_capacity(arity);
1933 let mut cur = typ.clone();
1934 for _ in 0..arity {
1935 match cur.as_ref() {
1936 TypeKind::Fun(a, b) => {
1937 args.push(a.clone());
1938 cur = b.clone();
1939 }
1940 _ => return None,
1941 }
1942 }
1943 Some((args, cur))
1944}
1945
1946type InferPatternResult = (Vec<Predicate>, Vec<(Symbol, Type)>);
1947
1948fn infer_pattern(
1949 unifier: &mut Unifier<'_>,
1950 supply: &mut TypeVarSupply,
1951 env: &TypeEnv,
1952 pat: &Pattern,
1953 scrutinee_ty: &Type,
1954) -> Result<InferPatternResult, TypeError> {
1955 let span = *pat.span();
1956 let res = (|| {
1957 unifier.charge_infer_node()?;
1958 match pat {
1959 Pattern::Wildcard(..) => Ok((vec![], vec![])),
1960 Pattern::Var(var) => Ok((
1961 vec![],
1962 vec![(var.name.clone(), unifier.apply_type(scrutinee_ty))],
1963 )),
1964 Pattern::Named(_, name, ps) => {
1965 let ctor_name = name.to_dotted_symbol();
1966 let schemes = env
1967 .lookup(&ctor_name)
1968 .ok_or_else(|| TypeError::UnknownVar(ctor_name.clone()))?;
1969 if schemes.len() != 1 {
1970 return Err(TypeError::AmbiguousOverload(ctor_name));
1971 }
1972 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
1973 let (preds, ctor_ty) = instantiate(&scheme, supply);
1974 let (arg_tys, res_ty) = decompose_fun(&ctor_ty, ps.len())
1975 .ok_or(TypeError::UnsupportedExpr("pattern constructor"))?;
1976 unifier.unify(&res_ty, scrutinee_ty)?;
1977 let mut all_preds = preds;
1978 let mut bindings = Vec::new();
1979 for (p, arg_ty) in ps.iter().zip(arg_tys.iter()) {
1980 let arg_ty = unifier.apply_type(arg_ty);
1981 let (p1, binds1) = infer_pattern(unifier, supply, env, p, &arg_ty)?;
1982 all_preds.extend(p1);
1983 bindings.extend(binds1);
1984 }
1985 let bindings = bindings
1986 .into_iter()
1987 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
1988 .collect();
1989 Ok((all_preds, bindings))
1990 }
1991 Pattern::List(_, ps) => {
1992 let elem_tv = Type::var(supply.fresh(Some("a".into())));
1993 let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
1994 unifier.unify(scrutinee_ty, &list_ty)?;
1995 let mut preds = Vec::new();
1996 let mut bindings = Vec::new();
1997 for p in ps {
1998 let elem_ty = unifier.apply_type(&elem_tv);
1999 let (p1, binds1) = infer_pattern(unifier, supply, env, p, &elem_ty)?;
2000 preds.extend(p1);
2001 bindings.extend(binds1);
2002 }
2003 let bindings = bindings
2004 .into_iter()
2005 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2006 .collect();
2007 Ok((preds, bindings))
2008 }
2009 Pattern::Cons(_, head, tail) => {
2010 let elem_tv = Type::var(supply.fresh(Some("a".into())));
2011 let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
2012 unifier.unify(scrutinee_ty, &list_ty)?;
2013 let mut preds = Vec::new();
2014 let mut bindings = Vec::new();
2015
2016 let head_ty = unifier.apply_type(&elem_tv);
2017 let (p1, binds1) = infer_pattern(unifier, supply, env, head, &head_ty)?;
2018 preds.extend(p1);
2019 bindings.extend(binds1);
2020
2021 let tail_ty = Type::app(
2022 Type::builtin(BuiltinTypeId::List),
2023 unifier.apply_type(&elem_tv),
2024 );
2025 let (p2, binds2) = infer_pattern(unifier, supply, env, tail, &tail_ty)?;
2026 preds.extend(p2);
2027 bindings.extend(binds2);
2028
2029 let bindings = bindings
2030 .into_iter()
2031 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2032 .collect();
2033 Ok((preds, bindings))
2034 }
2035 Pattern::Tuple(_, elems) => {
2036 let mut elem_tys: Vec<Type> = (0..elems.len())
2037 .map(|i| Type::var(supply.fresh(Some(format!("t{i}").into()))))
2038 .collect();
2039 let expected = Type::tuple(elem_tys.clone());
2040 unifier.unify(scrutinee_ty, &expected)?;
2041 elem_tys = elem_tys
2042 .into_iter()
2043 .map(|t| unifier.apply_type(&t))
2044 .collect();
2045
2046 let mut preds = Vec::new();
2047 let mut bindings = Vec::new();
2048 for (p, ty) in elems.iter().zip(elem_tys.iter()) {
2049 let (p_preds, p_binds) = infer_pattern(unifier, supply, env, p, ty)?;
2050 preds.extend(p_preds);
2051 bindings.extend(p_binds);
2052 }
2053 let bindings = bindings
2054 .into_iter()
2055 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2056 .collect();
2057 Ok((preds, bindings))
2058 }
2059 Pattern::Dict(_, fields) => {
2060 if let TypeKind::Record(ty_fields) = scrutinee_ty.as_ref() {
2061 let mut preds = Vec::new();
2062 let mut bindings = Vec::new();
2063 for (key, pat) in fields {
2064 let ty = ty_fields
2065 .iter()
2066 .find(|(name, _)| name == key)
2067 .map(|(_, ty)| unifier.apply_type(ty))
2068 .ok_or_else(|| TypeError::UnknownField {
2069 field: key.clone(),
2070 typ: scrutinee_ty.to_string(),
2071 })?;
2072 let (p_preds, p_binds) = infer_pattern(unifier, supply, env, pat, &ty)?;
2073 preds.extend(p_preds);
2074 bindings.extend(p_binds);
2075 }
2076 let bindings = bindings
2077 .into_iter()
2078 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2079 .collect();
2080 Ok((preds, bindings))
2081 } else {
2082 let elem_tv = Type::var(supply.fresh(Some("v".into())));
2083 let dict_ty = Type::app(Type::builtin(BuiltinTypeId::Dict), elem_tv.clone());
2084 unifier.unify(scrutinee_ty, &dict_ty)?;
2085 let elem_ty = unifier.apply_type(&elem_tv);
2086
2087 let mut preds = Vec::new();
2088 let mut bindings = Vec::new();
2089 for (_key, pat) in fields {
2090 let (p_preds, p_binds) =
2091 infer_pattern(unifier, supply, env, pat, &elem_ty)?;
2092 preds.extend(p_preds);
2093 bindings.extend(p_binds);
2094 }
2095 let bindings = bindings
2096 .into_iter()
2097 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2098 .collect();
2099 Ok((preds, bindings))
2100 }
2101 }
2102 }
2103 })();
2104 res.map_err(|err| with_span(&span, err))
2105}
2106
2107fn type_head_name(typ: &Type) -> Option<&Symbol> {
2108 let mut cur = typ;
2109 while let TypeKind::App(head, _) = cur.as_ref() {
2110 cur = head;
2111 }
2112 match cur.as_ref() {
2113 TypeKind::Con(tc) => Some(&tc.name),
2114 _ => None,
2115 }
2116}
2117
2118fn adt_name_from_patterns(adts: &HashMap<Symbol, AdtDecl>, patterns: &[Pattern]) -> Option<Symbol> {
2119 let mut candidate: Option<Symbol> = None;
2120 for pat in patterns {
2121 let next = match pat {
2122 Pattern::Named(_, name, _) => {
2123 let name_sym = name.to_dotted_symbol();
2124 ctor_lookup(adts, &name_sym).map(|(adt, _)| adt.name.clone())
2125 }
2126 Pattern::List(..) | Pattern::Cons(..) => Some(sym("List")),
2127 _ => None,
2128 };
2129 if let Some(next) = next {
2130 match &candidate {
2131 None => candidate = Some(next),
2132 Some(prev) if *prev == next => {}
2133 Some(_) => return None,
2134 }
2135 }
2136 }
2137 candidate
2138}
2139
2140fn check_match_exhaustive(
2141 adts: &HashMap<Symbol, AdtDecl>,
2142 scrutinee_ty: &Type,
2143 patterns: &[Pattern],
2144) -> Result<(), TypeError> {
2145 if patterns
2146 .iter()
2147 .any(|p| matches!(p, Pattern::Wildcard(..) | Pattern::Var(_)))
2148 {
2149 return Ok(());
2150 }
2151 let adt_name = match type_head_name(scrutinee_ty).cloned() {
2152 Some(name) => name,
2153 None => match adt_name_from_patterns(adts, patterns) {
2154 Some(name) => name,
2155 None => return Ok(()),
2156 },
2157 };
2158 let adt = match adts.get(&adt_name) {
2159 Some(adt) => adt,
2160 None => return Ok(()),
2161 };
2162 let ctor_names: HashSet<Symbol> = adt.variants.iter().map(|v| v.name.clone()).collect();
2163 if ctor_names.is_empty() {
2164 return Ok(());
2165 }
2166 let mut covered = HashSet::new();
2167 for pat in patterns {
2168 match pat {
2169 Pattern::Named(_, name, _) => {
2170 let name_sym = name.to_dotted_symbol();
2171 if ctor_names.contains(&name_sym) {
2172 covered.insert(name_sym);
2173 }
2174 }
2175 Pattern::List(_, elems) if adt_name.as_ref() == "List" && elems.is_empty() => {
2176 covered.insert(sym("Empty"));
2177 }
2178 Pattern::Cons(..) if adt_name.as_ref() == "List" => {
2179 covered.insert(sym("Cons"));
2180 }
2181 _ => {}
2182 }
2183 }
2184 let mut missing: Vec<Symbol> = ctor_names.difference(&covered).cloned().collect();
2185 if missing.is_empty() {
2186 return Ok(());
2187 }
2188 missing.sort();
2189 Err(TypeError::NonExhaustiveMatch {
2190 typ: scrutinee_ty.to_string(),
2191 missing,
2192 })
2193}
2194
2195#[cfg(test)]
2196mod tests {
2197 use super::*;
2198 use rexlang_lexer::Token;
2199 use rexlang_parser::Parser;
2200 use rexlang_util::{GasCosts, GasMeter};
2201
2202 fn tvar(id: TypeVarId, name: &str) -> Type {
2203 Type::var(TypeVar::new(id, Some(sym(name))))
2204 }
2205
2206 fn dict_of(elem: Type) -> Type {
2207 Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
2208 }
2209
2210 #[test]
2211 fn unify_simple() {
2212 let t1 = Type::fun(tvar(0, "a"), Type::builtin(BuiltinTypeId::U32));
2213 let t2 = Type::fun(Type::builtin(BuiltinTypeId::U16), tvar(1, "b"));
2214 let subst = unify(&t1, &t2).unwrap();
2215 assert_eq!(subst.get(&0), Some(&Type::builtin(BuiltinTypeId::U16)));
2216 assert_eq!(subst.get(&1), Some(&Type::builtin(BuiltinTypeId::U32)));
2217 }
2218
2219 #[test]
2220 fn occurs_check_blocks_infinite_type() {
2221 let tv = TypeVar::new(0, Some(sym("a")));
2222 let t = Type::fun(Type::var(tv.clone()), Type::builtin(BuiltinTypeId::U8));
2223 let err = bind(&tv, &t).unwrap_err();
2224 assert!(matches!(err, TypeError::Occurs(_, _)));
2225 }
2226
2227 #[test]
2228 fn instantiate_and_generalize_round_trip() {
2229 let mut supply = TypeVarSupply::new();
2230 let a = Type::var(supply.fresh(Some(sym("a"))));
2231 let scheme = generalize(&TypeEnv::new(), vec![], Type::fun(a.clone(), a.clone()));
2232 let (preds, inst) = instantiate(&scheme, &mut supply);
2233 assert!(preds.is_empty());
2234 if let TypeKind::Fun(l, r) = inst.as_ref() {
2235 match (l.as_ref(), r.as_ref()) {
2236 (TypeKind::Var(_), TypeKind::Var(_)) => {}
2237 _ => panic!("expected polymorphic identity"),
2238 }
2239 } else {
2240 panic!("expected function type");
2241 }
2242 }
2243
2244 #[test]
2245 fn entail_superclasses() {
2246 let ts = TypeSystem::new_with_prelude().unwrap();
2247 let pred = Predicate::new("Semiring", Type::builtin(BuiltinTypeId::I32));
2248 let given = [Predicate::new(
2249 "AdditiveGroup",
2250 Type::builtin(BuiltinTypeId::I32),
2251 )];
2252 assert!(entails(&ts.classes, &given, &pred).unwrap());
2253 }
2254
2255 #[test]
2256 fn entail_instances() {
2257 let ts = TypeSystem::new_with_prelude().unwrap();
2258 let pred = Predicate::new("Field", Type::builtin(BuiltinTypeId::F32));
2259 assert!(entails(&ts.classes, &[], &pred).unwrap());
2260
2261 let pred_fail = Predicate::new("Field", Type::builtin(BuiltinTypeId::U32));
2262 assert!(!entails(&ts.classes, &[], &pred_fail).unwrap());
2263 }
2264
2265 #[test]
2266 fn prelude_injects_functions() {
2267 let ts = TypeSystem::new_with_prelude().unwrap();
2268 let minus = ts.env.lookup(&sym("-")).expect("minus in env");
2269 let div = ts.env.lookup(&sym("/")).expect("div in env");
2270 assert_eq!(minus.len(), 1);
2271 assert_eq!(div.len(), 1);
2272 let minus = &minus[0];
2273 let div = &div[0];
2274 assert_eq!(minus.preds.len(), 1);
2275 assert_eq!(minus.vars.len(), 1);
2276 assert_eq!(div.preds.len(), 1);
2277 assert_eq!(div.vars.len(), 1);
2278 }
2279
2280 #[test]
2281 fn adt_constructors_are_present() {
2282 let ts = TypeSystem::new_with_prelude().unwrap();
2283 assert!(ts.env.lookup(&sym("Empty")).is_some());
2284 assert!(ts.env.lookup(&sym("Cons")).is_some());
2285 assert!(ts.env.lookup(&sym("Ok")).is_some());
2286 assert!(ts.env.lookup(&sym("Err")).is_some());
2287 assert!(ts.env.lookup(&sym("Some")).is_some());
2288 assert!(ts.env.lookup(&sym("None")).is_some());
2289 }
2290
2291 fn parse_expr(code: &str) -> std::sync::Arc<rexlang_ast::expr::Expr> {
2292 let mut parser = Parser::new(Token::tokenize(code).unwrap());
2293 parser.parse_program(&mut GasMeter::default()).unwrap().expr
2294 }
2295
2296 fn parse_program(code: &str) -> rexlang_ast::expr::Program {
2297 let mut parser = Parser::new(Token::tokenize(code).unwrap());
2298 parser.parse_program(&mut GasMeter::default()).unwrap()
2299 }
2300
2301 #[test]
2302 fn infer_deep_list_does_not_overflow() {
2303 const N: usize = 40;
2304 let mut code = String::new();
2305 code.push_str("let xs = ");
2306 for _ in 0..N {
2307 code.push_str("Cons 0 (");
2308 }
2309 code.push_str("Empty");
2310 for _ in 0..N {
2311 code.push(')');
2312 }
2313 code.push_str(" in xs");
2314
2315 let parse_handle = std::thread::Builder::new()
2316 .name("infer_deep_list_parse".into())
2317 .stack_size(128 * 1024 * 1024)
2318 .spawn(move || {
2319 let tokens = Token::tokenize(&code).unwrap();
2320 let mut parser = Parser::new(tokens);
2321 parser.parse_program(&mut GasMeter::default())
2322 })
2323 .unwrap();
2324 let program = parse_handle.join().unwrap().unwrap();
2325 let expr = program.expr;
2326 let mut ts = TypeSystem::new_with_prelude().unwrap();
2327 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2328 assert_eq!(
2329 ty,
2330 Type::app(
2331 Type::builtin(BuiltinTypeId::List),
2332 Type::builtin(BuiltinTypeId::I32)
2333 )
2334 );
2335 }
2336
2337 #[test]
2338 fn collect_adts_in_types_finds_nested_unique_adts() {
2339 let foo = Type::user_con("Foo", 1);
2340 let bar = Type::user_con("Bar", 0);
2341 let ty = Type::fun(
2342 Type::app(
2343 Type::builtin(BuiltinTypeId::List),
2344 Type::app(foo.clone(), tvar(0, "a")),
2345 ),
2346 Type::tuple(vec![
2347 Type::app(foo.clone(), Type::builtin(BuiltinTypeId::I32)),
2348 bar.clone(),
2349 ]),
2350 );
2351
2352 let adts = collect_adts_in_types(vec![ty]).unwrap();
2353 assert_eq!(adts, vec![foo, bar]);
2354 }
2355
2356 #[test]
2357 fn collect_adts_in_types_rejects_conflicting_names() {
2358 let arity1 = Type::user_con("Thing", 1);
2359 let arity2 = Type::user_con("Thing", 2);
2360
2361 let err = collect_adts_in_types(vec![arity1.clone(), arity2.clone()]).unwrap_err();
2362 assert_eq!(err.conflicts.len(), 1);
2363 let conflict = &err.conflicts[0];
2364 assert_eq!(conflict.name, sym("Thing"));
2365 assert_eq!(conflict.definitions, vec![arity1, arity2]);
2366 }
2367
2368 #[test]
2369 fn infer_depth_limit_is_enforced() {
2370 const N: usize = 40;
2371 let mut code = String::new();
2372 code.push_str("let xs = ");
2373 for _ in 0..N {
2374 code.push_str("Cons 0 (");
2375 }
2376 code.push_str("Empty");
2377 for _ in 0..N {
2378 code.push(')');
2379 }
2380 code.push_str(" in xs");
2381
2382 let program = parse_program(&code);
2383 let mut ts = TypeSystem::new_with_prelude().unwrap();
2384 ts.set_limits(TypeSystemLimits {
2385 max_infer_depth: Some(8),
2386 });
2387
2388 let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
2389 assert!(
2390 err.to_string().contains("maximum inference depth exceeded"),
2391 "expected a max-depth inference error, got: {err:?}"
2392 );
2393 }
2394
2395 #[test]
2396 fn declare_fn_injects_scheme_for_use_sites() {
2397 let program = parse_program(
2398 r#"
2399 declare fn id x: a -> a
2400 id 1
2401 "#,
2402 );
2403 let mut ts = TypeSystem::new_with_prelude().unwrap();
2404 ts.register_decls(&program.decls).unwrap();
2405 let (preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2406 assert!(
2407 preds.is_empty()
2408 || preds.iter().all(|p| p.class.as_ref() == "Integral"
2409 && p.typ == Type::builtin(BuiltinTypeId::I32))
2410 );
2411 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2412 }
2413
2414 #[test]
2415 fn declare_fn_is_noop_when_matching_existing_scheme() {
2416 let mut ts = TypeSystem::new_with_prelude().unwrap();
2417 ts.add_value(
2418 "foo",
2419 Scheme::new(
2420 vec![],
2421 vec![],
2422 Type::fun(
2423 Type::builtin(BuiltinTypeId::I32),
2424 Type::builtin(BuiltinTypeId::I32),
2425 ),
2426 ),
2427 );
2428
2429 let program = parse_program(
2430 r#"
2431 declare fn foo x: i32 -> i32
2432 0
2433 "#,
2434 );
2435 let rexlang_ast::expr::Decl::DeclareFn(fd) = &program.decls[0] else {
2436 panic!("expected declare fn decl");
2437 };
2438 ts.inject_declare_fn_decl(fd).unwrap();
2439 }
2440
2441 #[test]
2442 fn unit_type_parses_and_infers() {
2443 let program = parse_program(
2444 r#"
2445 fn unit_id x: () -> () = x
2446 unit_id ()
2447 "#,
2448 );
2449 let mut ts = TypeSystem::new_with_prelude().unwrap();
2450 ts.register_decls(&program.decls).unwrap();
2451 let (preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2452 assert!(preds.is_empty());
2453 assert_eq!(ty, Type::tuple(vec![]));
2454 }
2455
2456 fn strip_span(mut err: TypeError) -> TypeError {
2457 while let TypeError::Spanned { error, .. } = err {
2458 err = *error;
2459 }
2460 err
2461 }
2462
2463 #[test]
2464 fn type_errors_include_span() {
2465 let expr = parse_expr("missing");
2466 let mut ts = TypeSystem::new_with_prelude().unwrap();
2467 let err = infer(&mut ts, expr.as_ref()).unwrap_err();
2468 match err {
2469 TypeError::Spanned { span, error } => {
2470 assert_ne!(span, Span::default());
2471 assert!(matches!(
2472 *error,
2473 TypeError::UnknownVar(name) if name.as_ref() == "missing"
2474 ));
2475 }
2476 other => panic!("expected spanned error, got {other:?}"),
2477 }
2478 }
2479
2480 #[test]
2481 fn infer_with_gas_rejects_out_of_budget() {
2482 let expr = parse_expr("1");
2483 let mut ts = TypeSystem::new_with_prelude().unwrap();
2484 let mut gas = GasMeter::new(
2485 Some(0),
2486 GasCosts {
2487 infer_node: 1,
2488 unify_step: 0,
2489 ..GasCosts::sensible_defaults()
2490 },
2491 );
2492 let err = infer_with_gas(&mut ts, expr.as_ref(), &mut gas).unwrap_err();
2493 assert!(matches!(strip_span(err), TypeError::OutOfGas(..)));
2494 }
2495
2496 #[test]
2497 fn reject_user_redefinition_of_primitive_type_name() {
2498 let program = parse_program("type i32 = I32Wrap i32");
2499 let mut ts = TypeSystem::new_with_prelude().unwrap();
2500 let rexlang_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2501 panic!("expected type decl");
2502 };
2503 let err = ts.register_type_decl(decl).unwrap_err();
2504 assert!(matches!(
2505 err,
2506 TypeError::ReservedTypeName(name) if name.as_ref() == "i32"
2507 ));
2508 }
2509
2510 #[test]
2511 fn reject_user_redefinition_of_prelude_adt_name() {
2512 let program = parse_program("type Result e a = Nope e a");
2513 let mut ts = TypeSystem::new_with_prelude().unwrap();
2514 let rexlang_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2515 panic!("expected type decl");
2516 };
2517 let err = ts.register_type_decl(decl).unwrap_err();
2518 assert!(matches!(
2519 err,
2520 TypeError::ReservedTypeName(name) if name.as_ref() == "Result"
2521 ));
2522 }
2523
2524 #[test]
2525 fn reject_user_redefinition_of_promise_type_name() {
2526 let program = parse_program("type Promise a = PromiseWrap a");
2527 let mut ts = TypeSystem::new_with_prelude().unwrap();
2528 let rexlang_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2529 panic!("expected type decl");
2530 };
2531 let err = ts.register_type_decl(decl).unwrap_err();
2532 assert!(matches!(
2533 err,
2534 TypeError::ReservedTypeName(name) if name.as_ref() == "Promise"
2535 ));
2536 }
2537
2538 #[test]
2539 fn infer_polymorphic_id_tuple() {
2540 let expr = parse_expr(
2541 r#"
2542 let
2543 id = \x -> x
2544 in
2545 id (id 420, id 6.9, id "str")
2546 "#,
2547 );
2548 let mut ts = TypeSystem::new_with_prelude().unwrap();
2549 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2550 let expected = Type::tuple(vec![
2551 Type::builtin(BuiltinTypeId::I32),
2552 Type::builtin(BuiltinTypeId::F32),
2553 Type::builtin(BuiltinTypeId::String),
2554 ]);
2555 assert_eq!(ty, expected);
2556 }
2557
2558 #[test]
2559 fn infer_type_annotation_ok() {
2560 let expr = parse_expr("let x: i32 = 42 in x");
2561 let mut ts = TypeSystem::new_with_prelude().unwrap();
2562 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2563 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2564 }
2565
2566 #[test]
2567 fn infer_type_annotation_lambda_param() {
2568 let expr = parse_expr("\\ (a : f32) -> a");
2569 let mut ts = TypeSystem::new_with_prelude().unwrap();
2570 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2571 assert_eq!(
2572 ty,
2573 Type::fun(
2574 Type::builtin(BuiltinTypeId::F32),
2575 Type::builtin(BuiltinTypeId::F32)
2576 )
2577 );
2578 }
2579
2580 #[test]
2581 fn infer_type_annotation_is_alias() {
2582 let expr = parse_expr("\"hi\" is str");
2583 let mut ts = TypeSystem::new_with_prelude().unwrap();
2584 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2585 assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
2586 }
2587
2588 #[test]
2589 fn infer_type_annotation_with_promise_constructor() {
2590 let expr = parse_expr("\\(x: Promise i32) -> x");
2591 let mut ts = TypeSystem::new_with_prelude().unwrap();
2592 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2593 let promise_i32 = Type::promise(Type::builtin(BuiltinTypeId::I32));
2594 assert_eq!(ty, Type::fun(promise_i32.clone(), promise_i32));
2595 }
2596
2597 #[test]
2598 fn infer_type_annotation_mismatch_error() {
2599 let expr = parse_expr("let x: i32 = 3.14 in x");
2600 let mut ts = TypeSystem::new_with_prelude().unwrap();
2601 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
2602 assert!(matches!(err, TypeError::Unification(_, _)));
2603 }
2604
2605 #[test]
2606 fn infer_project_single_variant_let() {
2607 let program = parse_program(
2608 r#"
2609 type MyADT = MyVariant1 { field1: i32, field2: f32 }
2610 let
2611 x = MyVariant1 { field1 = 1, field2 = 2.0 }
2612 in
2613 (x.field1, x.field2)
2614 "#,
2615 );
2616 let mut ts = TypeSystem::new_with_prelude().unwrap();
2617 for decl in &program.decls {
2618 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2619 ts.register_type_decl(decl).unwrap();
2620 }
2621 }
2622 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2623 let expected = Type::tuple(vec![
2624 Type::builtin(BuiltinTypeId::I32),
2625 Type::builtin(BuiltinTypeId::F32),
2626 ]);
2627 assert_eq!(ty, expected);
2628 }
2629
2630 #[test]
2631 fn infer_project_known_variant_let() {
2632 let program = parse_program(
2633 r#"
2634 type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
2635 let
2636 x = MyVariant1 { field1 = 1, field2 = 2.0 }
2637 in
2638 x.field1
2639 "#,
2640 );
2641 let mut ts = TypeSystem::new_with_prelude().unwrap();
2642 for decl in &program.decls {
2643 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2644 ts.register_type_decl(decl).unwrap();
2645 }
2646 }
2647 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2648 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2649 }
2650
2651 #[test]
2652 fn infer_project_unknown_variant_error() {
2653 let program = parse_program(
2654 r#"
2655 type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
2656 let
2657 x = MyVariant2 1 2.0
2658 in
2659 x.field1
2660 "#,
2661 );
2662 let mut ts = TypeSystem::new_with_prelude().unwrap();
2663 for decl in &program.decls {
2664 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2665 ts.register_type_decl(decl).unwrap();
2666 }
2667 }
2668 let err = strip_span(infer(&mut ts, program.expr.as_ref()).unwrap_err());
2669 assert!(matches!(err, TypeError::FieldNotKnown { .. }));
2670 }
2671
2672 #[test]
2673 fn infer_project_lambda_param_single_variant() {
2674 let program = parse_program(
2675 r#"
2676 type Boxed = Boxed { value: i32 }
2677 let
2678 f = \x -> x.value
2679 in
2680 f (Boxed { value = 1 })
2681 "#,
2682 );
2683 let mut ts = TypeSystem::new_with_prelude().unwrap();
2684 for decl in &program.decls {
2685 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2686 ts.register_type_decl(decl).unwrap();
2687 }
2688 }
2689 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2690 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2691 }
2692
2693 #[test]
2694 fn infer_project_in_match_arm() {
2695 let program = parse_program(
2696 r#"
2697 type MyADT = MyVariant1 { field1: i32 } | MyVariant2 i32
2698 let
2699 x = MyVariant1 { field1 = 1 }
2700 in
2701 match x
2702 when MyVariant1 { field1 } -> x.field1
2703 when MyVariant2 _ -> 0
2704 "#,
2705 );
2706 let mut ts = TypeSystem::new_with_prelude().unwrap();
2707 for decl in &program.decls {
2708 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2709 ts.register_type_decl(decl).unwrap();
2710 }
2711 }
2712 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2713 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2714 }
2715
2716 #[test]
2717 fn infer_nested_let_lambda_match_option() {
2718 let expr = parse_expr(
2719 r#"
2720 let
2721 choose = \flag a b -> if flag then a else b,
2722 build = \flag ->
2723 let
2724 pick = choose flag,
2725 val = pick 1 2
2726 in
2727 Some val
2728 in
2729 match (build true)
2730 when Some x -> x
2731 when None -> 0
2732 "#,
2733 );
2734 let mut ts = TypeSystem::new_with_prelude().unwrap();
2735 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2736 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2737 }
2738
2739 #[test]
2740 fn infer_polymorphic_apply_in_tuple() {
2741 let expr = parse_expr(
2742 r#"
2743 let
2744 apply = \f x -> f x,
2745 id = \x -> x,
2746 wrap = \x -> (x, x)
2747 in
2748 (apply id 1, apply id "hi", apply wrap true)
2749 "#,
2750 );
2751 let mut ts = TypeSystem::new_with_prelude().unwrap();
2752 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2753 let expected = Type::tuple(vec![
2754 Type::builtin(BuiltinTypeId::I32),
2755 Type::builtin(BuiltinTypeId::String),
2756 Type::tuple(vec![
2757 Type::builtin(BuiltinTypeId::Bool),
2758 Type::builtin(BuiltinTypeId::Bool),
2759 ]),
2760 ]);
2761 assert_eq!(ty, expected);
2762 }
2763
2764 #[test]
2765 fn infer_nested_result_option_match() {
2766 let expr = parse_expr(
2767 r#"
2768 let
2769 unwrap = \x ->
2770 match x
2771 when Ok (Some v) -> v
2772 when Ok None -> 0
2773 when Err _ -> 0
2774 in
2775 unwrap (Ok (Some 5))
2776 "#,
2777 );
2778 let mut ts = TypeSystem::new_with_prelude().unwrap();
2779 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2780 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2781 }
2782
2783 #[test]
2784 fn infer_head_or_list_match() {
2785 let expr = parse_expr(
2786 r#"
2787 let
2788 head_or = \fallback xs ->
2789 match xs
2790 when [] -> fallback
2791 when x::xs -> x
2792 in
2793 (head_or 0 [1, 2, 3], head_or 0 [])
2794 "#,
2795 );
2796 let mut ts = TypeSystem::new_with_prelude().unwrap();
2797 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2798 let expected = Type::tuple(vec![
2799 Type::builtin(BuiltinTypeId::I32),
2800 Type::builtin(BuiltinTypeId::I32),
2801 ]);
2802 assert_eq!(ty, expected);
2803 }
2804
2805 #[test]
2806 fn infer_head_or_list_match_cons_constructor_form() {
2807 let expr = parse_expr(
2808 r#"
2809 let
2810 head_or = \fallback xs ->
2811 match xs
2812 when [] -> fallback
2813 when Cons x xs1 -> x
2814 in
2815 (head_or 0 (Cons 1 (Cons 2 Empty)), head_or 0 Empty)
2816 "#,
2817 );
2818 let mut ts = TypeSystem::new_with_prelude().unwrap();
2819 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2820 let expected = Type::tuple(vec![
2821 Type::builtin(BuiltinTypeId::I32),
2822 Type::builtin(BuiltinTypeId::I32),
2823 ]);
2824 assert_eq!(ty, expected);
2825 }
2826
2827 #[test]
2828 fn infer_record_pattern_in_lambda() {
2829 let program = parse_program(
2830 r#"
2831 type Pair = Pair { left: i32, right: i32 }
2832 let
2833 sum = \p ->
2834 match p
2835 when Pair { left, right } -> left + right
2836 in
2837 sum (Pair { left = 1, right = 2 })
2838 "#,
2839 );
2840 let mut ts = TypeSystem::new_with_prelude().unwrap();
2841 for decl in &program.decls {
2842 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2843 ts.register_type_decl(decl).unwrap();
2844 }
2845 }
2846 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2847 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2848 }
2849
2850 #[test]
2851 fn infer_fn_decl_simple() {
2852 let program = parse_program(
2853 r#"
2854 fn add (x: i32, y: i32) -> i32 = x + y
2855 add 1 2
2856 "#,
2857 );
2858 let mut ts = TypeSystem::new_with_prelude().unwrap();
2859 let expr = program.expr_with_fns();
2860 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2861 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2862 }
2863
2864 #[test]
2865 fn infer_fn_decl_signature_form() {
2866 let program = parse_program(
2867 r#"
2868 fn add : i32 -> i32 -> i32 = \x y -> x + y
2869 add 1 2
2870 "#,
2871 );
2872 let mut ts = TypeSystem::new_with_prelude().unwrap();
2873 let expr = program.expr_with_fns();
2874 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2875 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2876 }
2877
2878 #[test]
2879 fn infer_fn_decl_polymorphic_where_constraints() {
2880 let program = parse_program(
2881 r#"
2882 fn my_add (x: a, y: a) -> a where AdditiveMonoid a = x + y
2883 (my_add 1 2, my_add 1.0 2.0)
2884 "#,
2885 );
2886 let mut ts = TypeSystem::new_with_prelude().unwrap();
2887 let expr = program.expr_with_fns();
2888 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2889 assert_eq!(
2890 ty,
2891 Type::tuple(vec![
2892 Type::builtin(BuiltinTypeId::I32),
2893 Type::builtin(BuiltinTypeId::F32)
2894 ])
2895 );
2896 }
2897
2898 #[test]
2899 fn infer_additive_monoid_constraint() {
2900 let expr = parse_expr("\\x y -> x + y");
2901 let mut ts = TypeSystem::new_with_prelude().unwrap();
2902 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2903 assert_eq!(preds.len(), 1);
2904 assert_eq!(preds[0].class.as_ref(), "AdditiveMonoid");
2905
2906 if let TypeKind::Fun(a, rest) = ty.as_ref()
2907 && let TypeKind::Fun(b, c) = rest.as_ref()
2908 {
2909 assert_eq!(a.as_ref(), b.as_ref());
2910 assert_eq!(b.as_ref(), c.as_ref());
2911 assert_eq!(preds[0].typ, a.clone());
2912 return;
2913 }
2914 panic!("expected a -> a -> a");
2915 }
2916
2917 #[test]
2918 fn infer_multiplicative_monoid_constraint() {
2919 let expr = parse_expr("\\x y -> x * y");
2920 let mut ts = TypeSystem::new_with_prelude().unwrap();
2921 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2922 assert_eq!(preds.len(), 1);
2923 assert_eq!(preds[0].class.as_ref(), "MultiplicativeMonoid");
2924
2925 if let TypeKind::Fun(a, rest) = ty.as_ref()
2926 && let TypeKind::Fun(b, c) = rest.as_ref()
2927 {
2928 assert_eq!(a.as_ref(), b.as_ref());
2929 assert_eq!(b.as_ref(), c.as_ref());
2930 assert_eq!(preds[0].typ, a.clone());
2931 return;
2932 }
2933 panic!("expected a -> a -> a");
2934 }
2935
2936 #[test]
2937 fn infer_additive_group_constraint() {
2938 let expr = parse_expr("\\x y -> x - y");
2939 let mut ts = TypeSystem::new_with_prelude().unwrap();
2940 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2941 assert_eq!(preds.len(), 1);
2942 assert_eq!(preds[0].class.as_ref(), "AdditiveGroup");
2943
2944 if let TypeKind::Fun(a, rest) = ty.as_ref()
2945 && let TypeKind::Fun(b, c) = rest.as_ref()
2946 {
2947 assert_eq!(a.as_ref(), b.as_ref());
2948 assert_eq!(b.as_ref(), c.as_ref());
2949 assert_eq!(preds[0].typ, a.clone());
2950 return;
2951 }
2952 panic!("expected a -> a -> a");
2953 }
2954
2955 #[test]
2956 fn infer_integral_constraint() {
2957 let expr = parse_expr("\\x y -> x % y");
2958 let mut ts = TypeSystem::new_with_prelude().unwrap();
2959 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2960 assert_eq!(preds.len(), 1);
2961 assert_eq!(preds[0].class.as_ref(), "Integral");
2962
2963 if let TypeKind::Fun(a, rest) = ty.as_ref()
2964 && let TypeKind::Fun(b, c) = rest.as_ref()
2965 {
2966 assert_eq!(a.as_ref(), b.as_ref());
2967 assert_eq!(b.as_ref(), c.as_ref());
2968 assert_eq!(preds[0].typ, a.clone());
2969 return;
2970 }
2971 panic!("expected a -> a -> a");
2972 }
2973
2974 #[test]
2975 fn infer_literal_addition_defaults() {
2976 let expr = parse_expr("1 + 2");
2977 let mut ts = TypeSystem::new_with_prelude().unwrap();
2978 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2979 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2980 assert_eq!(preds.len(), 2);
2981 assert!(preds.iter().any(|p| p.class.as_ref() == "AdditiveMonoid"));
2982 assert!(preds.iter().any(|p| p.class.as_ref() == "Integral"));
2983 assert!(
2984 preds
2985 .iter()
2986 .all(|p| p.typ == Type::builtin(BuiltinTypeId::I32))
2987 );
2988 }
2989
2990 #[test]
2991 fn infer_mod_defaults() {
2992 let expr = parse_expr("1 % 2");
2993 let mut ts = TypeSystem::new_with_prelude().unwrap();
2994 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2995 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2996 assert_eq!(preds.len(), 1);
2997 assert_eq!(preds[0].class.as_ref(), "Integral");
2998 assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::I32));
2999 }
3000
3001 #[test]
3002 fn infer_get_list_type() {
3003 let expr = parse_expr("get 1 [1, 2, 3]");
3004 let mut ts = TypeSystem::new_with_prelude().unwrap();
3005 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3006 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3007 assert!(preds.iter().any(|p| p.class.as_ref() == "Indexable"));
3008 assert!(preds.iter().all(|p| {
3009 p.class.as_ref() == "Indexable"
3010 || (p.class.as_ref() == "Integral" && p.typ == Type::builtin(BuiltinTypeId::I32))
3011 }));
3012 for pred in preds.iter().filter(|p| p.class.as_ref() == "Indexable") {
3013 assert!(entails(&ts.classes, &[], pred).unwrap());
3014 }
3015 }
3016
3017 #[test]
3018 fn infer_get_tuple_type() {
3019 let expr = parse_expr("(1, 'Hello', true).0");
3020 let mut ts = TypeSystem::new_with_prelude().unwrap();
3021 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3022 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3023 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3024
3025 let expr = parse_expr("(1, 'Hello', true).1");
3026 let mut ts = TypeSystem::new_with_prelude().unwrap();
3027 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3028 assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
3029 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3030
3031 let expr = parse_expr("(1, 'Hello', true).2");
3032 let mut ts = TypeSystem::new_with_prelude().unwrap();
3033 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3034 assert_eq!(ty, Type::builtin(BuiltinTypeId::Bool));
3035 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3036 }
3037
3038 #[test]
3039 fn infer_division_defaults() {
3040 let expr = parse_expr("1.0 / 2.0");
3041 let mut ts = TypeSystem::new_with_prelude().unwrap();
3042 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3043 assert_eq!(ty, Type::builtin(BuiltinTypeId::F32));
3044 assert_eq!(preds.len(), 1);
3045 assert_eq!(preds[0].class.as_ref(), "Field");
3046 assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::F32));
3047 assert!(entails(&ts.classes, &[], &preds[0]).unwrap());
3048 }
3049
3050 #[test]
3051 fn infer_unbound_variable_error() {
3052 let expr = parse_expr("missing");
3053 let mut ts = TypeSystem::new_with_prelude().unwrap();
3054 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3055 assert!(matches!(
3056 err,
3057 TypeError::UnknownVar(name) if name.as_ref() == "missing"
3058 ));
3059 }
3060
3061 #[test]
3062 fn infer_if_branch_type_mismatch_error() {
3063 let expr = parse_expr(r#"if true then 1 else "no""#);
3064 let mut ts = TypeSystem::new_with_prelude().unwrap();
3065 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3066 match err {
3067 TypeError::Unification(a, b) => {
3068 let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
3069 assert!(ok, "expected i32 vs string, got {a} vs {b}");
3070 }
3071 other => panic!("expected unification error, got {other:?}"),
3072 }
3073 }
3074
3075 #[test]
3076 fn infer_unknown_pattern_constructor_error() {
3077 let expr = parse_expr("match 1 when Nope -> 1");
3078 let mut ts = TypeSystem::new_with_prelude().unwrap();
3079 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3080 assert!(matches!(
3081 err,
3082 TypeError::UnknownVar(name) if name.as_ref() == "Nope"
3083 ));
3084 }
3085
3086 #[test]
3087 fn infer_ambiguous_overload_error() {
3088 let mut ts = TypeSystem::new();
3089 let a = TypeVar::new(0, Some(sym("a")));
3090 let b = TypeVar::new(1, Some(sym("b")));
3091 let scheme_a = Scheme::new(vec![a.clone()], vec![], Type::var(a));
3092 let scheme_b = Scheme::new(vec![b.clone()], vec![], Type::var(b));
3093 ts.add_overload(sym("dup"), scheme_a);
3094 ts.add_overload(sym("dup"), scheme_b);
3095 let expr = parse_expr("dup");
3096 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3097 assert!(matches!(
3098 err,
3099 TypeError::AmbiguousOverload(name) if name.as_ref() == "dup"
3100 ));
3101 }
3102
3103 #[test]
3104 fn infer_if_cond_not_bool_error() {
3105 let expr = parse_expr("if 1 then 2 else 3");
3106 let mut ts = TypeSystem::new_with_prelude().unwrap();
3107 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3108 match err {
3109 TypeError::Unification(a, b) => {
3110 let ok = (a == "bool" && b == "i32") || (a == "i32" && b == "bool");
3111 assert!(ok, "expected bool vs i32, got {a} vs {b}");
3112 }
3113 other => panic!("expected unification error, got {other:?}"),
3114 }
3115 }
3116
3117 #[test]
3118 fn infer_apply_non_function_error() {
3119 let expr = parse_expr("1 2");
3120 let mut ts = TypeSystem::new_with_prelude().unwrap();
3121 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3122 assert!(matches!(err, TypeError::Unification(_, _)));
3123 }
3124
3125 #[test]
3126 fn infer_list_element_mismatch_error() {
3127 let expr = parse_expr("[1, true]");
3128 let mut ts = TypeSystem::new_with_prelude().unwrap();
3129 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3130 match err {
3131 TypeError::Unification(a, b) => {
3132 let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
3133 assert!(ok, "expected i32 vs bool, got {a} vs {b}");
3134 }
3135 other => panic!("expected unification error, got {other:?}"),
3136 }
3137 }
3138
3139 #[test]
3140 fn infer_dict_value_mismatch_error() {
3141 let expr = parse_expr("{a = 1, b = true}");
3142 let mut ts = TypeSystem::new_with_prelude().unwrap();
3143 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3144 match err {
3145 TypeError::Unification(a, b) => {
3146 let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
3147 assert!(ok, "expected i32 vs bool, got {a} vs {b}");
3148 }
3149 other => panic!("expected unification error, got {other:?}"),
3150 }
3151 }
3152
3153 #[test]
3154 fn infer_match_list_on_non_list_error() {
3155 let expr = parse_expr("match 1 when [x] -> x");
3156 let mut ts = TypeSystem::new_with_prelude().unwrap();
3157 assert!(infer(&mut ts, expr.as_ref()).is_err());
3158 }
3159
3160 #[test]
3161 fn infer_pattern_constructor_arity_error() {
3162 let expr = parse_expr("match (Ok 1) when Ok x y -> x");
3163 let mut ts = TypeSystem::new_with_prelude().unwrap();
3164 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3165 assert!(matches!(
3166 err,
3167 TypeError::UnsupportedExpr("pattern constructor")
3168 ));
3169 }
3170
3171 #[test]
3172 fn infer_match_arm_type_mismatch_error() {
3173 let expr = parse_expr(r#"match 1 when _ -> 1 when _ -> "no""#);
3174 let mut ts = TypeSystem::new_with_prelude().unwrap();
3175 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3176 match err {
3177 TypeError::Unification(a, b) => {
3178 let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
3179 assert!(ok, "expected i32 vs string, got {a} vs {b}");
3180 }
3181 other => panic!("expected unification error, got {other:?}"),
3182 }
3183 }
3184
3185 #[test]
3186 fn infer_match_option_on_non_option_error() {
3187 let expr = parse_expr("match 1 when Some x -> x");
3188 let mut ts = TypeSystem::new_with_prelude().unwrap();
3189 assert!(infer(&mut ts, expr.as_ref()).is_err());
3190 }
3191
3192 #[test]
3193 fn infer_dict_pattern_on_non_dict_error() {
3194 let expr = parse_expr("match 1 when {a} -> a");
3195 let mut ts = TypeSystem::new_with_prelude().unwrap();
3196 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3197 assert!(matches!(err, TypeError::Unification(_, _)));
3198 }
3199
3200 #[test]
3201 fn infer_cons_pattern_on_non_list_error() {
3202 let expr = parse_expr("match 1 when x::xs -> x");
3203 let mut ts = TypeSystem::new_with_prelude().unwrap();
3204 assert!(infer(&mut ts, expr.as_ref()).is_err());
3205 }
3206
3207 #[test]
3208 fn infer_apply_wrong_arg_type_error() {
3209 let expr = parse_expr("(\\x -> x + 1) true");
3210 let mut ts = TypeSystem::new_with_prelude().unwrap();
3211 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3212 assert!(matches!(err, TypeError::Unification(_, _)));
3213 }
3214
3215 #[test]
3216 fn infer_self_application_occurs_error() {
3217 let expr = parse_expr("\\x -> x x");
3218 let mut ts = TypeSystem::new_with_prelude().unwrap();
3219 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3220 assert!(matches!(err, TypeError::Occurs(_, _)));
3221 }
3222
3223 #[test]
3224 fn infer_apply_constructor_too_many_args_error() {
3225 let expr = parse_expr("Some 1 2");
3226 let mut ts = TypeSystem::new_with_prelude().unwrap();
3227 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3228 assert!(matches!(err, TypeError::Unification(_, _)));
3229 }
3230
3231 #[test]
3232 fn infer_operator_type_mismatch_error() {
3233 let expr = parse_expr("1 + true");
3234 let mut ts = TypeSystem::new_with_prelude().unwrap();
3235 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3236 assert!(matches!(err, TypeError::Unification(_, _)));
3237 }
3238
3239 #[test]
3240 fn infer_non_exhaustive_match_is_error() {
3241 let expr = parse_expr("match (Ok 1) when Ok x -> x");
3242 let mut ts = TypeSystem::new_with_prelude().unwrap();
3243 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3244 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3245 }
3246
3247 #[test]
3248 fn infer_non_exhaustive_match_on_bound_var_error() {
3249 let expr = parse_expr("let x = Ok 1 in match x when Ok y -> y");
3250 let mut ts = TypeSystem::new_with_prelude().unwrap();
3251 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3252 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3253 }
3254
3255 #[test]
3256 fn infer_non_exhaustive_match_in_lambda_error() {
3257 let expr = parse_expr("\\x -> match x when Ok y -> y");
3258 let mut ts = TypeSystem::new_with_prelude().unwrap();
3259 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3260 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3261 }
3262
3263 #[test]
3264 fn infer_non_exhaustive_option_match_error() {
3265 let expr = parse_expr("match (Some 1) when Some x -> x");
3266 let mut ts = TypeSystem::new_with_prelude().unwrap();
3267 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3268 match err {
3269 TypeError::NonExhaustiveMatch { missing, .. } => {
3270 assert_eq!(missing, vec![sym("None")]);
3271 }
3272 other => panic!("expected non-exhaustive match, got {other:?}"),
3273 }
3274 }
3275
3276 #[test]
3277 fn infer_non_exhaustive_result_match_error() {
3278 let expr = parse_expr("match (Err 1) when Ok x -> x");
3279 let mut ts = TypeSystem::new_with_prelude().unwrap();
3280 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3281 match err {
3282 TypeError::NonExhaustiveMatch { missing, .. } => {
3283 assert_eq!(missing, vec![sym("Err")]);
3284 }
3285 other => panic!("expected non-exhaustive match, got {other:?}"),
3286 }
3287 }
3288
3289 #[test]
3290 fn infer_non_exhaustive_list_missing_empty_error() {
3291 let expr = parse_expr("match [1, 2] when x::xs -> x");
3292 let mut ts = TypeSystem::new_with_prelude().unwrap();
3293 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3294 match err {
3295 TypeError::NonExhaustiveMatch { missing, .. } => {
3296 assert_eq!(missing, vec![sym("Empty")]);
3297 }
3298 other => panic!("expected non-exhaustive match, got {other:?}"),
3299 }
3300 }
3301
3302 #[test]
3303 fn infer_non_exhaustive_list_match_on_bound_var_error() {
3304 let expr = parse_expr("let xs = [1, 2] in match xs when x::xs -> x");
3305 let mut ts = TypeSystem::new_with_prelude().unwrap();
3306 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3307 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3308 }
3309
3310 #[test]
3311 fn infer_non_exhaustive_list_missing_cons_error() {
3312 let expr = parse_expr("match [1] when [] -> 0");
3313 let mut ts = TypeSystem::new_with_prelude().unwrap();
3314 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3315 match err {
3316 TypeError::NonExhaustiveMatch { missing, .. } => {
3317 assert_eq!(missing, vec![sym("Cons")]);
3318 }
3319 other => panic!("expected non-exhaustive match, got {other:?}"),
3320 }
3321 }
3322
3323 #[test]
3324 fn infer_match_list_patterns_on_result_error() {
3325 let expr = parse_expr("match (Ok 1) when [] -> 0 when x::xs -> 1");
3326 let mut ts = TypeSystem::new_with_prelude().unwrap();
3327 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3328 assert!(matches!(err, TypeError::Unification(_, _)));
3329 }
3330
3331 #[test]
3332 fn infer_missing_instances_produce_unsatisfied_predicates() {
3333 for (name, code) in [
3334 ("division", "1 / 2"),
3335 ("eq_dict", "{a = 1} == {a = 2}"),
3336 ("min_bool", "min [true]"),
3337 ("map_dict", r#"map (\x -> x) {a = 1}"#),
3338 ] {
3339 let (class, pred_type, expected_ty) = match name {
3340 "division" => (
3341 "Field",
3342 Type::builtin(BuiltinTypeId::I32),
3343 Some(Type::builtin(BuiltinTypeId::I32)),
3344 ),
3345 "eq_dict" => ("Eq", dict_of(Type::builtin(BuiltinTypeId::I32)), None),
3346 "min_bool" => ("Ord", Type::builtin(BuiltinTypeId::Bool), None),
3347 "map_dict" => ("Functor", Type::builtin(BuiltinTypeId::Dict), None),
3348 _ => unreachable!("unknown test case {name}"),
3349 };
3350
3351 let expr = parse_expr(code);
3352 let mut ts = TypeSystem::new_with_prelude().unwrap();
3353 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3354 if let Some(expected) = expected_ty {
3355 assert_eq!(ty, expected, "{name}");
3356 }
3357
3358 let pred = preds
3359 .iter()
3360 .find(|p| p.class.as_ref() == class && p.typ == pred_type)
3361 .unwrap();
3362 assert!(!entails(&ts.classes, &[], pred).unwrap(), "{name}");
3363 }
3364 }
3365
3366 #[test]
3367 fn record_update_single_variant_adt_infers() {
3368 let program = parse_program(
3369 r#"
3370 type Foo = Bar { x: i32, y: i32 }
3371 let
3372 foo: Foo = Bar { x = 1, y = 2 },
3373 bar = { foo with { x = 3 } }
3374 in
3375 bar
3376 "#,
3377 );
3378 let mut ts = TypeSystem::new_with_prelude().unwrap();
3379 ts.register_decls(&program.decls).unwrap();
3380 let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3381 assert_eq!(typ.to_string(), "Foo");
3382 }
3383
3384 #[test]
3385 fn record_update_unknown_field_errors() {
3386 let program = parse_program(
3387 r#"
3388 type Foo = Bar { x: i32 }
3389 let
3390 foo: Foo = Bar { x = 1 }
3391 in
3392 { foo with { y = 2 } }
3393 "#,
3394 );
3395 let mut ts = TypeSystem::new_with_prelude().unwrap();
3396 ts.register_decls(&program.decls).unwrap();
3397 let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
3398 let err = strip_span(err);
3399 assert!(matches!(err, TypeError::UnknownField { .. }));
3400 }
3401
3402 #[test]
3403 fn record_update_requires_refined_variant_for_sum_types() {
3404 let program = parse_program(
3405 r#"
3406 type Foo = Bar { x: i32 } | Baz { x: i32 }
3407 let
3408 f = \ (foo : Foo) -> { foo with { x = 2 } }
3409 in
3410 f (Bar { x = 1 })
3411 "#,
3412 );
3413 let mut ts = TypeSystem::new_with_prelude().unwrap();
3414 ts.register_decls(&program.decls).unwrap();
3415 let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
3416 let err = strip_span(err);
3417 assert!(matches!(err, TypeError::FieldNotKnown { .. }));
3418 }
3419
3420 #[test]
3421 fn record_update_allowed_after_match_refines_variant() {
3422 let program = parse_program(
3423 r#"
3424 type Foo = Bar { x: i32 } | Baz { x: i32 }
3425 let
3426 f = \ (foo : Foo) ->
3427 match foo
3428 when Bar {x} -> { foo with { x = x + 1 } }
3429 when Baz {x} -> { foo with { x = x + 2 } }
3430 in
3431 f (Bar { x = 1 })
3432 "#,
3433 );
3434 let mut ts = TypeSystem::new_with_prelude().unwrap();
3435 ts.register_decls(&program.decls).unwrap();
3436 let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3437 assert_eq!(typ.to_string(), "Foo");
3438 }
3439
3440 #[test]
3441 fn record_update_plain_record_type() {
3442 let program = parse_program(
3443 r#"
3444 let
3445 f = \ (r : { x: i32, y: i32 }) -> { r with { y = 9 } }
3446 in
3447 f { x = 1, y = 2 }
3448 "#,
3449 );
3450 let mut ts = TypeSystem::new_with_prelude().unwrap();
3451 ts.register_decls(&program.decls).unwrap();
3452 let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3453 assert_eq!(typ.to_string(), "{x: i32, y: i32}");
3454 }
3455
3456 #[test]
3457 fn infer_typed_hole_expr_is_hole_kind() {
3458 let expr = parse_expr("?");
3459 let mut ts = TypeSystem::new_with_prelude().unwrap();
3460 let (typed, _preds, _ty) = infer_typed(&mut ts, expr.as_ref()).unwrap();
3461 assert!(
3462 matches!(typed.kind, TypedExprKind::Hole),
3463 "typed={typed:#?}"
3464 );
3465 }
3466
3467 #[test]
3468 fn infer_hole_with_annotation_unifies_to_annotation() {
3469 let expr = parse_expr("let x : i32 = ? in x");
3470 let mut ts = TypeSystem::new_with_prelude().unwrap();
3471 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3472 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3473 }
3474
3475 #[test]
3476 fn infer_hole_in_if_condition_is_bool_constrained() {
3477 let expr = parse_expr("if ? then 1 else 2");
3478 let mut ts = TypeSystem::new_with_prelude().unwrap();
3479 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3480 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3481 }
3482
3483 #[test]
3484 fn infer_hole_in_arithmetic_is_numeric_constrained() {
3485 let expr = parse_expr("? + 1");
3486 let mut ts = TypeSystem::new_with_prelude().unwrap();
3487 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3488 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3489 }
3490
3491 #[test]
3492 fn infer_hole_arithmetic_conflicting_annotation_failure() {
3493 let expr = parse_expr("let x : string = (? + 1) in x");
3494 let mut ts = TypeSystem::new_with_prelude().unwrap();
3495 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3496 assert!(matches!(err, TypeError::Unification(_, _)), "err={err:#?}");
3497 }
3498}