1use super::*;
2
3fn dedup_preds(preds: Vec<Predicate>) -> Vec<Predicate> {
4 let mut seen = HashSet::new();
5 let mut out = Vec::with_capacity(preds.len());
6 for pred in preds {
7 if seen.insert(pred.clone()) {
8 out.push(pred);
9 }
10 }
11 out
12}
13
14fn is_integral_primitive(typ: &Type) -> bool {
15 matches!(
16 typ.as_ref(),
17 TypeKind::Con(TypeConst {
18 builtin_id: Some(
19 BuiltinTypeId::U8
20 | BuiltinTypeId::U16
21 | BuiltinTypeId::U32
22 | BuiltinTypeId::U64
23 | BuiltinTypeId::I8
24 | BuiltinTypeId::I16
25 | BuiltinTypeId::I32
26 | BuiltinTypeId::I64
27 ),
28 ..
29 })
30 )
31}
32
33fn finalize_infer_for_public_api(
34 mut preds: Vec<Predicate>,
35 mut typ: Type,
36) -> Result<(Vec<Predicate>, Type), TypeError> {
37 let mut subst = Subst::new_sync();
38 for pred in &preds {
39 if pred.class.as_ref() == "Integral"
40 && let TypeKind::Var(tv) = pred.typ.as_ref()
41 {
42 subst = subst.insert(tv.id, Type::builtin(BuiltinTypeId::I32));
43 }
44 }
45
46 if !subst_is_empty(&subst) {
47 preds = dedup_preds(preds.apply(&subst));
48 typ = typ.apply(&subst);
49 }
50
51 for pred in &preds {
52 if pred.class.as_ref() != "Integral" {
53 continue;
54 }
55 if matches!(pred.typ.as_ref(), TypeKind::Var(_)) || is_integral_primitive(&pred.typ) {
56 continue;
57 }
58 return Err(TypeError::Unification("i32".into(), pred.typ.to_string()));
59 }
60
61 Ok((preds, typ))
62}
63
64#[derive(Clone, Debug)]
65struct KnownVariant {
66 adt: Symbol,
67 variant: Symbol,
68}
69
70type KnownVariants = HashMap<Symbol, KnownVariant>;
71
72fn apply_scheme_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> Scheme {
73 let preds = scheme
74 .preds
75 .iter()
76 .map(|pred| Predicate::new(pred.class.clone(), unifier.apply_type(&pred.typ)))
77 .collect();
78 let typ = unifier.apply_type(&scheme.typ);
79 Scheme::new(scheme.vars.clone(), preds, typ)
80}
81
82fn scheme_ftv_with_unifier(scheme: &Scheme, unifier: &mut Unifier<'_>) -> HashSet<TypeVarId> {
83 let mut ftv = unifier.apply_type(&scheme.typ).ftv();
84 for pred in &scheme.preds {
85 ftv.extend(unifier.apply_type(&pred.typ).ftv());
86 }
87 for var in &scheme.vars {
88 ftv.remove(&var.id);
89 }
90 ftv
91}
92
93fn env_ftv_with_unifier(env: &TypeEnv, unifier: &mut Unifier<'_>) -> HashSet<TypeVarId> {
94 let mut out = HashSet::new();
95 for (_name, schemes) in env.values.iter() {
96 for scheme in schemes {
97 out.extend(scheme_ftv_with_unifier(scheme, unifier));
98 }
99 }
100 out
101}
102
103fn generalize_with_unifier(
104 env: &TypeEnv,
105 preds: Vec<Predicate>,
106 typ: Type,
107 unifier: &mut Unifier<'_>,
108) -> Scheme {
109 let preds: Vec<Predicate> = preds
110 .into_iter()
111 .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
112 .collect();
113 let typ = unifier.apply_type(&typ);
114 let mut vars: Vec<TypeVar> = typ
115 .ftv()
116 .union(&preds.ftv())
117 .copied()
118 .collect::<HashSet<_>>()
119 .difference(&env_ftv_with_unifier(env, unifier))
120 .cloned()
121 .map(|id| TypeVar::new(id, None))
122 .collect();
123 vars.sort_by_key(|v| v.id);
124 Scheme::new(vars, preds, typ)
125}
126
127fn monomorphic_scheme_with_unifier(
128 preds: Vec<Predicate>,
129 typ: Type,
130 unifier: &mut Unifier<'_>,
131) -> Scheme {
132 let preds = dedup_preds(
133 preds
134 .into_iter()
135 .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
136 .collect(),
137 );
138 let typ = unifier.apply_type(&typ);
139 Scheme::new(vec![], preds, typ)
140}
141
142pub fn infer_typed(
143 type_system: &mut TypeSystem,
144 expr: &Expr,
145) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
146 infer_typed_inner(type_system, expr)
147}
148
149pub fn infer_typed_with_gas(
150 type_system: &mut TypeSystem,
151 expr: &Expr,
152 gas: &mut GasMeter,
153) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
154 let known = KnownVariants::new();
155 let mut unifier = Unifier::with_gas(gas, type_system.limits.max_infer_depth);
156 let (preds, t, typed) = infer_expr(
157 &mut unifier,
158 &mut type_system.supply,
159 &type_system.env,
160 &type_system.adts,
161 &known,
162 expr,
163 )
164 .map_err(|err| with_span(expr.span(), err))?;
165 let subst = unifier.into_subst();
166 let mut typed = typed.apply(&subst);
167 let mut preds = dedup_preds(preds.apply(&subst));
168 let mut t = t.apply(&subst);
169 let improve = improve_indexable(&preds)?;
170 if !subst_is_empty(&improve) {
171 typed = typed.apply(&improve);
172 preds = dedup_preds(preds.apply(&improve));
173 t = t.apply(&improve);
174 }
175 type_system.check_predicate_kinds(&preds)?;
176 Ok((typed, preds, t))
177}
178
179fn infer_typed_inner(
180 type_system: &mut TypeSystem,
181 expr: &Expr,
182) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
183 let known = KnownVariants::new();
184 let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
185 let (preds, t, typed) = infer_expr(
186 &mut unifier,
187 &mut type_system.supply,
188 &type_system.env,
189 &type_system.adts,
190 &known,
191 expr,
192 )
193 .map_err(|err| with_span(expr.span(), err))?;
194 let subst = unifier.into_subst();
195 let mut typed = typed.apply(&subst);
196 let mut preds = dedup_preds(preds.apply(&subst));
197 let mut t = t.apply(&subst);
198 let improve = improve_indexable(&preds)?;
199 if !subst_is_empty(&improve) {
200 typed = typed.apply(&improve);
201 preds = dedup_preds(preds.apply(&improve));
202 t = t.apply(&improve);
203 }
204 type_system.check_predicate_kinds(&preds)?;
205 Ok((typed, preds, t))
206}
207
208pub fn infer(
209 type_system: &mut TypeSystem,
210 expr: &Expr,
211) -> Result<(Vec<Predicate>, Type), TypeError> {
212 infer_inner(type_system, expr)
213}
214
215pub fn infer_with_gas(
216 type_system: &mut TypeSystem,
217 expr: &Expr,
218 gas: &mut GasMeter,
219) -> Result<(Vec<Predicate>, Type), TypeError> {
220 let known = KnownVariants::new();
221 let mut unifier = Unifier::with_gas(gas, type_system.limits.max_infer_depth);
222 let (preds, t) = infer_expr_type(
223 &mut unifier,
224 &mut type_system.supply,
225 &type_system.env,
226 &type_system.adts,
227 &known,
228 expr,
229 )
230 .map_err(|err| with_span(expr.span(), err))?;
231 let subst = unifier.into_subst();
232 let preds = dedup_preds(preds.apply(&subst));
233 let t = t.apply(&subst);
234 type_system.check_predicate_kinds(&preds)?;
235 finalize_infer_for_public_api(preds, t)
236}
237
238fn infer_inner(
239 type_system: &mut TypeSystem,
240 expr: &Expr,
241) -> Result<(Vec<Predicate>, Type), TypeError> {
242 let known = KnownVariants::new();
243 let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
244 let (preds, t) = infer_expr_type(
245 &mut unifier,
246 &mut type_system.supply,
247 &type_system.env,
248 &type_system.adts,
249 &known,
250 expr,
251 )
252 .map_err(|err| with_span(expr.span(), err))?;
253 let subst = unifier.into_subst();
254 let mut preds = dedup_preds(preds.apply(&subst));
255 let mut t = t.apply(&subst);
256 let improve = improve_indexable(&preds)?;
257 if !subst_is_empty(&improve) {
258 preds = dedup_preds(preds.apply(&improve));
259 t = t.apply(&improve);
260 }
261 type_system.check_predicate_kinds(&preds)?;
262 finalize_infer_for_public_api(preds, t)
263}
264
265fn improve_indexable(preds: &[Predicate]) -> Result<Subst, TypeError> {
266 let mut subst = Subst::new_sync();
267 loop {
268 let mut changed = false;
269 for pred in preds {
270 let pred = pred.apply(&subst);
271 if pred.class.as_ref() != "Indexable" {
272 continue;
273 }
274 let TypeKind::Tuple(parts) = pred.typ.as_ref() else {
275 continue;
276 };
277 if parts.len() != 2 {
278 continue;
279 }
280 let container = parts[0].clone();
281 let elem = parts[1].clone();
282 let s = indexable_elem_subst(&container, &elem)?;
283 if !subst_is_empty(&s) {
284 subst = compose_subst(s, subst);
285 changed = true;
286 }
287 }
288 if !changed {
289 break;
290 }
291 }
292 Ok(subst)
293}
294
295fn indexable_elem_subst(container: &Type, elem: &Type) -> Result<Subst, TypeError> {
296 match container.as_ref() {
297 TypeKind::App(head, arg) => match head.as_ref() {
298 TypeKind::Con(tc)
299 if matches!(
300 tc.builtin_id,
301 Some(BuiltinTypeId::List | BuiltinTypeId::Array)
302 ) =>
303 {
304 unify(elem, arg)
305 }
306 _ => Ok(Subst::new_sync()),
307 },
308 TypeKind::Tuple(elems) => {
309 if elems.is_empty() {
310 return Ok(Subst::new_sync());
311 }
312 let mut subst = Subst::new_sync();
313 let mut cur = elems[0].clone();
314 for ty in elems.iter().skip(1) {
315 let s_next = unify(&cur.apply(&subst), &ty.apply(&subst))?;
316 subst = compose_subst(s_next, subst);
317 cur = cur.apply(&subst);
318 }
319 let elem = elem.apply(&subst);
320 let s_elem = unify(&elem, &cur.apply(&subst))?;
321 Ok(compose_subst(s_elem, subst))
322 }
323 _ => Ok(Subst::new_sync()),
324 }
325}
326
327type LambdaChain<'a> = (
328 Vec<(Symbol, Option<TypeExpr>)>,
329 Vec<TypeConstraint>,
330 &'a Expr,
331);
332
333fn collect_lambda_chain<'a>(expr: &'a Expr) -> LambdaChain<'a> {
334 let mut params = Vec::new();
335 let mut constraints = Vec::new();
336 let mut cur = expr;
337 let mut seen_constraints = false;
338 while let Expr::Lam(_, _scope, param, ann, lam_constraints, body) = cur {
339 if !lam_constraints.is_empty() {
340 if seen_constraints {
341 break;
342 }
343 constraints = lam_constraints.clone();
344 seen_constraints = true;
345 }
346 params.push((param.name.clone(), ann.clone()));
347 cur = body.as_ref();
348 }
349 (params, constraints, cur)
350}
351
352fn collect_app_chain(expr: &Expr) -> (&Expr, Vec<&Expr>) {
353 let mut args = Vec::new();
354 let mut cur = expr;
355 while let Expr::App(_, f, x) = cur {
356 args.push(x.as_ref());
357 cur = f.as_ref();
358 }
359 args.reverse();
360 (cur, args)
361}
362
363fn narrow_overload_candidates(candidates: &[Type], arg_ty: &Type) -> Vec<Type> {
364 let mut out = Vec::new();
365 for candidate in candidates {
366 let Some((params, ret)) = decompose_fun(candidate, 1) else {
367 continue;
368 };
369 let param = ¶ms[0];
370 if let Ok(s) = unify(param, arg_ty) {
371 out.push(ret.apply(&s));
372 }
373 }
374 out
375}
376
377fn unary_app_arg(typ: &Type, ctor_name: &str) -> Option<Type> {
378 let TypeKind::App(head, arg) = typ.as_ref() else {
379 return None;
380 };
381 let TypeKind::Con(tc) = head.as_ref() else {
382 return None;
383 };
384 (tc.name.as_ref() == ctor_name && tc.arity == 1).then(|| arg.clone())
385}
386
387fn infer_app_arg_type(
388 unifier: &mut Unifier<'_>,
389 supply: &mut TypeVarSupply,
390 env: &TypeEnv,
391 adts: &HashMap<Symbol, AdtDecl>,
392 known: &KnownVariants,
393 arg_hint: Option<Type>,
394 arg: &Expr,
395) -> Result<(Vec<Predicate>, Type), TypeError> {
396 match (arg_hint, arg) {
397 (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
398 infer_record_update_type_with_hint(
399 unifier,
400 supply,
401 env,
402 adts,
403 known,
404 base.as_ref(),
405 updates,
406 &arg_hint,
407 )
408 }
409 (Some(arg_hint), Expr::Dict(_, kvs))
410 if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
411 {
412 let TypeKind::Record(fields) = arg_hint.as_ref() else {
413 unreachable!("guarded by matches!")
414 };
415 let expected: HashMap<_, _> =
416 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
417 let mut seen = HashSet::new();
418 let mut preds = Vec::new();
419 for (k, v) in kvs {
420 let expected_ty = expected
421 .get(k)
422 .ok_or_else(|| TypeError::UnknownField {
423 field: k.clone(),
424 typ: Type::record(fields.clone()).to_string(),
425 })?
426 .clone();
427 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
428 unifier.unify(&t1, &expected_ty)?;
429 preds.extend(p1);
430 seen.insert(k.clone());
431 }
432 for key in expected.keys() {
433 if !seen.contains(key.as_ref()) {
434 return Err(TypeError::UnknownField {
435 field: key.clone(),
436 typ: Type::record(fields.clone()).to_string(),
437 });
438 }
439 }
440 let record_ty = Type::record(
441 fields
442 .iter()
443 .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
444 .collect(),
445 );
446 Ok((preds, record_ty))
447 }
448 _ => infer_expr_type(unifier, supply, env, adts, known, arg),
449 }
450}
451
452fn infer_app_arg_typed(
453 unifier: &mut Unifier<'_>,
454 supply: &mut TypeVarSupply,
455 env: &TypeEnv,
456 adts: &HashMap<Symbol, AdtDecl>,
457 known: &KnownVariants,
458 arg_hint: Option<Type>,
459 arg: &Expr,
460) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
461 match (arg_hint, arg) {
462 (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
463 infer_record_update_typed_with_hint(
464 unifier,
465 supply,
466 env,
467 adts,
468 known,
469 base.as_ref(),
470 updates,
471 &arg_hint,
472 )
473 }
474 (Some(arg_hint), Expr::Dict(_, kvs))
475 if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
476 {
477 let TypeKind::Record(fields) = arg_hint.as_ref() else {
478 unreachable!("guarded by matches!")
479 };
480 let mut preds = Vec::new();
481 let mut typed_kvs = BTreeMap::new();
482 let expected: HashMap<_, _> =
483 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
484 for (k, v) in kvs {
485 let expected_ty = expected
486 .get(k)
487 .ok_or_else(|| TypeError::UnknownField {
488 field: k.clone(),
489 typ: Type::record(fields.clone()).to_string(),
490 })?
491 .clone();
492 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
493 unifier.unify(&t1, &expected_ty)?;
494 preds.extend(p1);
495 typed_kvs.insert(k.clone(), typed_v);
496 }
497 for key in expected.keys() {
498 if !typed_kvs.contains_key(key.as_ref()) {
499 return Err(TypeError::UnknownField {
500 field: key.clone(),
501 typ: Type::record(fields.clone()).to_string(),
502 });
503 }
504 }
505 let record_ty = Type::record(
506 fields
507 .iter()
508 .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
509 .collect(),
510 );
511 let typed = TypedExpr::new(record_ty.clone(), TypedExprKind::Dict(typed_kvs));
512 Ok((preds, record_ty, typed))
513 }
514 _ => infer_expr(unifier, supply, env, adts, known, arg),
515 }
516}
517
518#[allow(clippy::too_many_arguments)]
519fn infer_record_update_type_with_hint(
520 unifier: &mut Unifier<'_>,
521 supply: &mut TypeVarSupply,
522 env: &TypeEnv,
523 adts: &HashMap<Symbol, AdtDecl>,
524 known: &KnownVariants,
525 base: &Expr,
526 updates: &BTreeMap<Symbol, Arc<Expr>>,
527 hint_ty: &Type,
528) -> Result<(Vec<Predicate>, Type), TypeError> {
529 let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
530 unifier.unify(&t_base, hint_ty)?;
531 let base_ty = unifier.apply_type(&t_base);
532 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
533 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
534 let (result_ty, fields) = resolve_record_update(
535 unifier,
536 supply,
537 adts,
538 &base_ty,
539 known_variant,
540 &update_fields,
541 )?;
542 let expected: HashMap<_, _> = fields.into_iter().collect();
543
544 let mut preds = p_base;
545 for (k, v) in updates {
546 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
547 field: k.clone(),
548 typ: result_ty.to_string(),
549 })?;
550 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
551 unifier.unify(&t1, expected_ty)?;
552 preds.extend(p1);
553 }
554 Ok((preds, result_ty))
555}
556
557#[allow(clippy::too_many_arguments)]
558fn infer_record_update_typed_with_hint(
559 unifier: &mut Unifier<'_>,
560 supply: &mut TypeVarSupply,
561 env: &TypeEnv,
562 adts: &HashMap<Symbol, AdtDecl>,
563 known: &KnownVariants,
564 base: &Expr,
565 updates: &BTreeMap<Symbol, Arc<Expr>>,
566 hint_ty: &Type,
567) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
568 let (p_base, t_base, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
569 unifier.unify(&t_base, hint_ty)?;
570 let base_ty = unifier.apply_type(&t_base);
571 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
572 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
573 let (result_ty, fields) = resolve_record_update(
574 unifier,
575 supply,
576 adts,
577 &base_ty,
578 known_variant,
579 &update_fields,
580 )?;
581 let expected: HashMap<_, _> = fields.into_iter().collect();
582
583 let mut preds = p_base;
584 let mut typed_updates = BTreeMap::new();
585 for (k, v) in updates {
586 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
587 field: k.clone(),
588 typ: result_ty.to_string(),
589 })?;
590 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
591 unifier.unify(&t1, expected_ty)?;
592 preds.extend(p1);
593 typed_updates.insert(k.clone(), typed_v);
594 }
595
596 let typed = TypedExpr::new(
597 result_ty.clone(),
598 TypedExprKind::RecordUpdate {
599 base: Box::new(typed_base),
600 updates: typed_updates,
601 },
602 );
603 Ok((preds, result_ty, typed))
604}
605
606fn infer_expr_type(
607 unifier: &mut Unifier<'_>,
608 supply: &mut TypeVarSupply,
609 env: &TypeEnv,
610 adts: &HashMap<Symbol, AdtDecl>,
611 known: &KnownVariants,
612 expr: &Expr,
613) -> Result<(Vec<Predicate>, Type), TypeError> {
614 let span = *expr.span();
615 let res = unifier.with_infer_depth(span, |unifier| {
616 infer_expr_type_inner(unifier, supply, env, adts, known, expr)
617 });
618 res.map_err(|err| with_span(&span, err))
619}
620
621fn infer_expr_type_inner(
622 unifier: &mut Unifier<'_>,
623 supply: &mut TypeVarSupply,
624 env: &TypeEnv,
625 adts: &HashMap<Symbol, AdtDecl>,
626 known: &KnownVariants,
627 expr: &Expr,
628) -> Result<(Vec<Predicate>, Type), TypeError> {
629 unifier.charge_infer_node()?;
630 match expr {
631 Expr::Bool(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Bool))),
632 Expr::Uint(_, _) => {
633 let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
634 Ok((vec![Predicate::new("Integral", lit_ty.clone())], lit_ty))
635 }
636 Expr::Int(_, _) => {
637 let lit_ty = Type::var(supply.fresh(Some(sym("n"))));
638 Ok((
639 vec![
640 Predicate::new("Integral", lit_ty.clone()),
641 Predicate::new("AdditiveGroup", lit_ty.clone()),
642 ],
643 lit_ty,
644 ))
645 }
646 Expr::Float(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::F32))),
647 Expr::String(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::String))),
648 Expr::Uuid(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Uuid))),
649 Expr::DateTime(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::DateTime))),
650 Expr::Hole(_) => {
651 let t = Type::var(supply.fresh(Some(sym("hole"))));
652 Ok((vec![], t))
653 }
654 Expr::Var(var) => {
655 let schemes = env
656 .lookup(&var.name)
657 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
658 if schemes.len() == 1 {
659 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
660 let (preds, t) = instantiate(&scheme, supply);
661 Ok((preds, t))
662 } else {
663 for scheme in schemes {
664 if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
665 return Err(TypeError::AmbiguousOverload(var.name.clone()));
666 }
667 }
668 let t = Type::var(supply.fresh(Some(var.name.clone())));
669 Ok((vec![], t))
670 }
671 }
672 Expr::Lam(..) => {
673 let (params, constraints, body) = collect_lambda_chain(expr);
674 let mut ann_vars = HashMap::new();
675 let mut param_tys = Vec::with_capacity(params.len());
676 for (name, ann) in ¶ms {
677 let param_ty = match ann {
678 Some(ann) => type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?,
679 None => Type::var(supply.fresh(Some(name.clone()))),
680 };
681 param_tys.push((name.clone(), param_ty));
682 }
683
684 let mut env1 = env.clone();
685 let mut known_body = known.clone();
686 for (name, param_ty) in ¶m_tys {
687 env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
688 known_body.remove(name);
689 }
690
691 let (mut preds, body_ty) =
692 infer_expr_type(unifier, supply, &env1, adts, &known_body, body)?;
693 let constraint_preds =
694 predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
695 preds.extend(constraint_preds);
696
697 let mut fun_ty = unifier.apply_type(&body_ty);
698 for (_, param_ty) in param_tys.iter().rev() {
699 fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
700 }
701 Ok((preds, fun_ty))
702 }
703 Expr::App(..) => {
704 let (head, args) = collect_app_chain(expr);
705 let (mut preds, mut func_ty) =
706 infer_expr_type(unifier, supply, env, adts, known, head)?;
707 let mut overload_name = None;
708 let mut overload_candidates = if let Expr::Var(var) = head {
709 if let Some(schemes) = env.lookup(&var.name) {
710 if schemes.len() <= 1 {
711 None
712 } else {
713 let mut candidates = Vec::new();
714 for scheme in schemes {
715 if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
716 return Err(TypeError::AmbiguousOverload(var.name.clone()));
717 }
718 let scheme = apply_scheme_with_unifier(scheme, unifier);
719 let (p, typ) = instantiate(&scheme, supply);
720 if !p.is_empty() {
721 return Err(TypeError::AmbiguousOverload(var.name.clone()));
722 }
723 candidates.push(typ);
724 }
725 overload_name = Some(var.name.clone());
726 Some(candidates)
727 }
728 } else {
729 None
730 }
731 } else {
732 None
733 };
734 for arg in args {
735 let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
736 TypeKind::Fun(arg, _) => Some(arg.clone()),
737 _ => None,
738 };
739 let (p_arg, arg_ty) =
740 infer_app_arg_type(unifier, supply, env, adts, known, arg_hint, arg)?;
741 let arg_ty = unifier.apply_type(&arg_ty);
742 if let Some(candidates) = overload_candidates.take() {
743 let candidates = candidates
744 .into_iter()
745 .map(|t| unifier.apply_type(&t))
746 .collect::<Vec<_>>();
747 let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
748 if narrowed.is_empty()
749 && let Some(name) = &overload_name
750 {
751 return Err(TypeError::AmbiguousOverload(name.clone()));
752 }
753 overload_candidates = Some(narrowed);
754 }
755 let res_ty = match overload_candidates.as_ref() {
756 Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
757 _ => Type::var(supply.fresh(Some("r".into()))),
758 };
759 unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
760 preds.extend(p_arg);
761 func_ty = match overload_candidates.as_ref() {
762 Some(candidates) if candidates.len() == 1 => unifier.apply_type(&candidates[0]),
763 _ => unifier.apply_type(&res_ty),
764 };
765 }
766 Ok((preds, func_ty))
767 }
768 Expr::Project(_, base, field) => {
769 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, base)?;
770 let base_ty = unifier.apply_type(&t1);
771 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
772 let field_ty =
773 resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
774 Ok((p1, field_ty))
775 }
776 Expr::RecordUpdate(_, base, updates) => {
777 let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
778 let base_ty = unifier.apply_type(&t_base);
779 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
780 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
781 let (result_ty, fields) = resolve_record_update(
782 unifier,
783 supply,
784 adts,
785 &base_ty,
786 known_variant,
787 &update_fields,
788 )?;
789 let expected: HashMap<_, _> = fields.into_iter().collect();
790
791 let mut preds = p_base;
792 for (k, v) in updates {
793 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
794 field: k.clone(),
795 typ: result_ty.to_string(),
796 })?;
797 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
798 unifier.unify(&t1, expected_ty)?;
799 preds.extend(p1);
800 }
801 Ok((preds, result_ty))
802 }
803 Expr::Let(..) => {
804 let mut bindings = Vec::new();
805 let mut cur = expr;
806 while let Expr::Let(_, v, ann, d, b) = cur {
807 bindings.push((v.clone(), ann.clone(), d.clone()));
808 cur = b.as_ref();
809 }
810
811 let mut env_cur = env.clone();
812 let mut known_cur = known.clone();
813 for (v, ann, d) in bindings {
814 let (p1, t1) = if let Some(ref ann_expr) = ann {
815 let mut ann_vars = HashMap::new();
816 let ann_ty =
817 type_from_annotation_expr_vars(adts, ann_expr, &mut ann_vars, supply)?;
818 match d.as_ref() {
819 Expr::RecordUpdate(_, base, updates) => infer_record_update_type_with_hint(
820 unifier,
821 supply,
822 &env_cur,
823 adts,
824 &known_cur,
825 base.as_ref(),
826 updates,
827 &ann_ty,
828 )?,
829 _ => {
830 let (p1, t1) =
831 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?;
832 unifier.unify(&t1, &ann_ty)?;
833 (p1, t1)
834 }
835 }
836 } else {
837 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?
838 };
839 let def_ty = unifier.apply_type(&t1);
840 let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
841 monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
842 } else {
843 let scheme = generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
844 reject_ambiguous_scheme(&scheme)?;
845 scheme
846 };
847 env_cur.extend(v.name.clone(), scheme);
848 if let Some(known_variant) =
849 known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
850 {
851 known_cur.insert(
852 v.name.clone(),
853 KnownVariant {
854 adt: known_variant.adt,
855 variant: known_variant.variant,
856 },
857 );
858 } else {
859 known_cur.remove(&v.name);
860 }
861 }
862
863 let (p_body, t_body) =
864 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, cur)?;
865 Ok((p_body, t_body))
866 }
867 Expr::LetRec(_, bindings, body) => {
868 let mut env_seed = env.clone();
869 let mut known_seed = known.clone();
870 let mut binding_tys = HashMap::new();
871 for (var, _ann, _def) in bindings {
872 let tv = Type::var(supply.fresh(Some(var.name.clone())));
873 env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
874 known_seed.remove(&var.name);
875 binding_tys.insert(var.name.clone(), tv);
876 }
877
878 let mut inferred = Vec::with_capacity(bindings.len());
879 for (var, ann, def) in bindings {
880 let (preds, def_ty) =
881 infer_expr_type(unifier, supply, &env_seed, adts, &known_seed, def)?;
882 if let Some(ann) = ann {
883 let mut ann_vars = HashMap::new();
884 let ann_ty = type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
885 unifier.unify(&def_ty, &ann_ty)?;
886 }
887 let binding_ty = binding_tys
888 .get(&var.name)
889 .cloned()
890 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
891 unifier.unify(&binding_ty, &def_ty)?;
892 let resolved_ty = unifier.apply_type(&binding_ty);
893
894 if let Some(known_variant) =
895 known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
896 {
897 known_seed.insert(
898 var.name.clone(),
899 KnownVariant {
900 adt: known_variant.adt,
901 variant: known_variant.variant,
902 },
903 );
904 } else {
905 known_seed.remove(&var.name);
906 }
907 inferred.push((var.name.clone(), preds, resolved_ty));
908 }
909
910 let mut env_body = env.clone();
911 for (name, preds, def_ty) in inferred {
912 let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
913 reject_ambiguous_scheme(&scheme)?;
914 env_body.extend(name, scheme);
915 }
916
917 let (p_body, t_body) =
918 infer_expr_type(unifier, supply, &env_body, adts, &known_seed, body)?;
919 Ok((p_body, t_body))
920 }
921 Expr::Ite(_, cond, then_expr, else_expr) => {
922 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, cond)?;
923 unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
924 let (p2, t2) = infer_expr_type(unifier, supply, env, adts, known, then_expr)?;
925 let (p3, t3) = infer_expr_type(unifier, supply, env, adts, known, else_expr)?;
926 unifier.unify(&t2, &t3)?;
927 let out_ty = unifier.apply_type(&t2);
928 let mut preds = p1;
929 preds.extend(p2);
930 preds.extend(p3);
931 Ok((preds, out_ty))
932 }
933 Expr::Tuple(_, elems) => {
934 let mut preds = Vec::new();
935 let mut types = Vec::new();
936 for elem in elems {
937 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
938 preds.extend(p1);
939 types.push(unifier.apply_type(&t1));
940 }
941 let tuple_ty = Type::tuple(types);
942 Ok((preds, tuple_ty))
943 }
944 Expr::List(_, elems) => {
945 let elem_tv = Type::var(supply.fresh(Some("a".into())));
946 let mut preds = Vec::new();
947 for elem in elems {
948 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
949 unifier.unify(&t1, &elem_tv)?;
950 preds.extend(p1);
951 }
952 let list_ty = Type::app(
953 Type::builtin(BuiltinTypeId::List),
954 unifier.apply_type(&elem_tv),
955 );
956 Ok((preds, list_ty))
957 }
958 Expr::Dict(_, kvs) => {
959 let elem_tv = Type::var(supply.fresh(Some("v".into())));
960 let mut preds = Vec::new();
961 for v in kvs.values() {
962 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
963 unifier.unify(&t1, &elem_tv)?;
964 preds.extend(p1);
965 }
966 let dict_ty = Type::app(
967 Type::builtin(BuiltinTypeId::Dict),
968 unifier.apply_type(&elem_tv),
969 );
970 Ok((preds, dict_ty))
971 }
972 Expr::Match(_, scrutinee, arms) => {
973 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, scrutinee.as_ref())?;
974 let mut preds = p1;
975 let res_ty = Type::var(supply.fresh(Some("match".into())));
976 let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
977
978 for (pat, expr) in arms {
979 let scrutinee_ty = unifier.apply_type(&t1);
980 let (p_pat, binds) = infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
981 preds.extend(p_pat);
982
983 let mut env_arm = env.clone();
984 for (name, ty) in binds {
985 env_arm.extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
986 }
987 let mut known_arm = known.clone();
988 if let Expr::Var(var) = scrutinee.as_ref() {
989 match pat {
990 Pattern::Named(_, name, _) => {
991 let name_sym = name.to_dotted_symbol();
992 if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
993 known_arm.insert(
994 var.name.clone(),
995 KnownVariant {
996 adt: adt.name.clone(),
997 variant: name_sym,
998 },
999 );
1000 } else {
1001 known_arm.remove(&var.name);
1002 }
1003 }
1004 _ => {
1005 known_arm.remove(&var.name);
1006 }
1007 }
1008 }
1009 let (p_expr, t_expr) =
1010 infer_expr_type(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1011 unifier.unify(&res_ty, &t_expr)?;
1012 preds.extend(p_expr);
1013 }
1014
1015 let scrutinee_ty = unifier.apply_type(&t1);
1016 check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1017 let out_ty = unifier.apply_type(&res_ty);
1018 Ok((preds, out_ty))
1019 }
1020 Expr::Ann(_, expr, ann) => {
1021 let ann_ty = type_from_annotation_expr(adts, ann)?;
1022 match expr.as_ref() {
1023 Expr::RecordUpdate(_, base, updates) => {
1024 let (preds, out_ty) = infer_record_update_type_with_hint(
1025 unifier,
1026 supply,
1027 env,
1028 adts,
1029 known,
1030 base.as_ref(),
1031 updates,
1032 &ann_ty,
1033 )?;
1034 Ok((preds, out_ty))
1035 }
1036 _ => {
1037 let (preds, expr_ty) =
1038 infer_expr_type(unifier, supply, env, adts, known, expr)?;
1039 unifier.unify(&expr_ty, &ann_ty)?;
1040 let out_ty = unifier.apply_type(&ann_ty);
1041 Ok((preds, out_ty))
1042 }
1043 }
1044 }
1045 }
1046}
1047
1048fn infer_expr(
1049 unifier: &mut Unifier<'_>,
1050 supply: &mut TypeVarSupply,
1051 env: &TypeEnv,
1052 adts: &HashMap<Symbol, AdtDecl>,
1053 known: &KnownVariants,
1054 expr: &Expr,
1055) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
1056 let span = *expr.span();
1057 let res = unifier.with_infer_depth(span, |unifier| {
1058 (|| {
1059 unifier.charge_infer_node()?;
1060 match expr {
1061 Expr::Bool(_, v) => {
1062 let t = Type::builtin(BuiltinTypeId::Bool);
1063 Ok((
1064 vec![],
1065 t.clone(),
1066 TypedExpr::new(t, TypedExprKind::Bool(*v)),
1067 ))
1068 }
1069 Expr::Uint(_, v) => {
1070 let t = Type::var(supply.fresh(Some(sym("n"))));
1071 Ok((
1072 vec![Predicate::new("Integral", t.clone())],
1073 t.clone(),
1074 TypedExpr::new(t, TypedExprKind::Uint(*v)),
1075 ))
1076 }
1077 Expr::Int(_, v) => {
1078 let t = Type::var(supply.fresh(Some(sym("n"))));
1079 Ok((
1080 vec![
1081 Predicate::new("Integral", t.clone()),
1082 Predicate::new("AdditiveGroup", t.clone()),
1083 ],
1084 t.clone(),
1085 TypedExpr::new(t, TypedExprKind::Int(*v)),
1086 ))
1087 }
1088 Expr::Float(_, v) => {
1089 let t = Type::builtin(BuiltinTypeId::F32);
1090 Ok((
1091 vec![],
1092 t.clone(),
1093 TypedExpr::new(t, TypedExprKind::Float(*v)),
1094 ))
1095 }
1096 Expr::String(_, v) => {
1097 let t = Type::builtin(BuiltinTypeId::String);
1098 Ok((
1099 vec![],
1100 t.clone(),
1101 TypedExpr::new(t, TypedExprKind::String(v.clone())),
1102 ))
1103 }
1104 Expr::Uuid(_, v) => {
1105 let t = Type::builtin(BuiltinTypeId::Uuid);
1106 Ok((
1107 vec![],
1108 t.clone(),
1109 TypedExpr::new(t, TypedExprKind::Uuid(*v)),
1110 ))
1111 }
1112 Expr::DateTime(_, v) => {
1113 let t = Type::builtin(BuiltinTypeId::DateTime);
1114 Ok((
1115 vec![],
1116 t.clone(),
1117 TypedExpr::new(t, TypedExprKind::DateTime(*v)),
1118 ))
1119 }
1120 Expr::Hole(_) => {
1121 let t = Type::var(supply.fresh(Some(sym("hole"))));
1122 Ok((vec![], t.clone(), TypedExpr::new(t, TypedExprKind::Hole)))
1123 }
1124 Expr::Var(var) => {
1125 let schemes = env
1126 .lookup(&var.name)
1127 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1128 if schemes.len() == 1 {
1129 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
1130 let (preds, t) = instantiate(&scheme, supply);
1131 let typed = TypedExpr::new(
1132 t.clone(),
1133 TypedExprKind::Var {
1134 name: var.name.clone(),
1135 overloads: vec![],
1136 },
1137 );
1138 Ok((preds, t, typed))
1139 } else {
1140 let mut overloads = Vec::new();
1141 for scheme in schemes {
1142 if !scheme.preds.is_empty() {
1143 return Err(TypeError::AmbiguousOverload(var.name.clone()));
1144 }
1145
1146 let scheme = apply_scheme_with_unifier(scheme, unifier);
1147 let (preds, typ) = instantiate(&scheme, supply);
1148 if !preds.is_empty() {
1149 return Err(TypeError::AmbiguousOverload(var.name.clone()));
1150 }
1151 overloads.push(typ);
1152 }
1153 let t = Type::var(supply.fresh(Some(var.name.clone())));
1154 let typed = TypedExpr::new(
1155 t.clone(),
1156 TypedExprKind::Var {
1157 name: var.name.clone(),
1158 overloads,
1159 },
1160 );
1161 Ok((vec![], t, typed))
1162 }
1163 }
1164 Expr::Lam(..) => {
1165 let (params, constraints, body) = collect_lambda_chain(expr);
1166 let mut ann_vars = HashMap::new();
1167 let mut param_tys = Vec::with_capacity(params.len());
1168 for (name, ann) in ¶ms {
1169 let param_ty = match ann {
1170 Some(ann) => {
1171 type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?
1172 }
1173 None => Type::var(supply.fresh(Some(name.clone()))),
1174 };
1175 param_tys.push((name.clone(), param_ty));
1176 }
1177
1178 let mut env1 = env.clone();
1179 let mut known_body = known.clone();
1180 for (name, param_ty) in ¶m_tys {
1181 env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
1182 known_body.remove(name);
1183 }
1184
1185 let (mut preds, body_ty, typed_body) =
1186 infer_expr(unifier, supply, &env1, adts, &known_body, body)?;
1187 let constraint_preds =
1188 predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
1189 preds.extend(constraint_preds);
1190
1191 let mut typed = typed_body;
1192 let mut fun_ty = unifier.apply_type(&body_ty);
1193 for (name, param_ty) in param_tys.iter().rev() {
1194 fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
1195 typed = TypedExpr::new(
1196 fun_ty.clone(),
1197 TypedExprKind::Lam {
1198 param: name.clone(),
1199 body: Box::new(typed),
1200 },
1201 );
1202 }
1203
1204 Ok((preds, fun_ty, typed))
1205 }
1206 Expr::App(..) => {
1207 let (head, args) = collect_app_chain(expr);
1208 let (mut preds, mut func_ty, mut typed) =
1209 infer_expr(unifier, supply, env, adts, known, head)?;
1210 let mut overload_name = None;
1211 let mut overload_candidates = match &typed.kind {
1212 TypedExprKind::Var { name, overloads } if !overloads.is_empty() => {
1213 overload_name = Some(name.clone());
1214 Some(overloads.clone())
1215 }
1216 _ => None,
1217 };
1218 for arg in args {
1219 let expected_arg = match unifier.apply_type(&func_ty).as_ref() {
1220 TypeKind::Fun(arg, _) => Some(arg.clone()),
1221 _ => None,
1222 };
1223 let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
1224 TypeKind::Fun(arg, _) => Some(arg.clone()),
1225 _ => None,
1226 };
1227 let (p_arg, arg_ty, typed_arg) =
1228 infer_app_arg_typed(unifier, supply, env, adts, known, arg_hint, arg)?;
1229 let mut arg_ty = unifier.apply_type(&arg_ty);
1230 let mut typed_arg = typed_arg;
1231
1232 if let Some(expected_arg) = expected_arg {
1233 let expected_arg = unifier.apply_type(&expected_arg);
1234 if let (Some(expected_elem), Some(arg_elem)) = (
1235 unary_app_arg(&expected_arg, "Array"),
1236 unary_app_arg(&arg_ty, "List"),
1237 ) {
1238 unifier.unify(&expected_elem, &arg_elem)?;
1239 let elem_ty = unifier.apply_type(&expected_elem);
1240 let list_ty = Type::list(elem_ty.clone());
1241 let array_ty = Type::array(elem_ty);
1242 let coercion_ty = Type::fun(list_ty, array_ty.clone());
1243 let coercion_fn = TypedExpr::new(
1244 coercion_ty,
1245 TypedExprKind::Var {
1246 name: sym("prim_array_from_list"),
1247 overloads: vec![],
1248 },
1249 );
1250 typed_arg = TypedExpr::new(
1251 array_ty.clone(),
1252 TypedExprKind::App(Box::new(coercion_fn), Box::new(typed_arg)),
1253 );
1254 arg_ty = array_ty;
1255 }
1256 }
1257 if let Some(candidates) = overload_candidates.take() {
1258 let candidates = candidates
1259 .into_iter()
1260 .map(|t| unifier.apply_type(&t))
1261 .collect::<Vec<_>>();
1262 let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
1263 if narrowed.is_empty()
1264 && let Some(name) = &overload_name
1265 {
1266 return Err(TypeError::AmbiguousOverload(name.clone()));
1267 }
1268 overload_candidates = Some(narrowed);
1269 }
1270 let res_ty = match overload_candidates.as_ref() {
1271 Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
1272 _ => Type::var(supply.fresh(Some("r".into()))),
1273 };
1274 unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
1275 let result_ty = match overload_candidates.as_ref() {
1276 Some(candidates) if candidates.len() == 1 => {
1277 unifier.apply_type(&candidates[0])
1278 }
1279 _ => unifier.apply_type(&res_ty),
1280 };
1281 preds.extend(p_arg);
1282 typed = TypedExpr::new(
1283 result_ty.clone(),
1284 TypedExprKind::App(Box::new(typed), Box::new(typed_arg)),
1285 );
1286 func_ty = result_ty;
1287 }
1288 Ok((preds, func_ty, typed))
1289 }
1290 Expr::Project(_, base, field) => {
1291 let (p1, t1, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
1292 let base_ty = unifier.apply_type(&t1);
1293 let known_variant =
1294 known_variant_from_expr_with_known(base, &base_ty, adts, known);
1295 let field_ty =
1296 resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
1297 let typed = TypedExpr::new(
1298 field_ty.clone(),
1299 TypedExprKind::Project {
1300 expr: Box::new(typed_base),
1301 field: field.clone(),
1302 },
1303 );
1304 Ok((p1, field_ty, typed))
1305 }
1306 Expr::RecordUpdate(_, base, updates) => {
1307 let (p_base, t_base, typed_base) =
1308 infer_expr(unifier, supply, env, adts, known, base)?;
1309 let base_ty = unifier.apply_type(&t_base);
1310 let known_variant =
1311 known_variant_from_expr_with_known(base, &base_ty, adts, known);
1312 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
1313 let (result_ty, fields) = resolve_record_update(
1314 unifier,
1315 supply,
1316 adts,
1317 &base_ty,
1318 known_variant,
1319 &update_fields,
1320 )?;
1321 let expected: HashMap<_, _> = fields.into_iter().collect();
1322
1323 let mut preds = p_base;
1324 let mut typed_updates = BTreeMap::new();
1325 for (k, v) in updates {
1326 let expected_ty =
1327 expected.get(k).ok_or_else(|| TypeError::UnknownField {
1328 field: k.clone(),
1329 typ: result_ty.to_string(),
1330 })?;
1331 let (p1, t1, typed_v) =
1332 infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
1333 unifier.unify(&t1, expected_ty)?;
1334 preds.extend(p1);
1335 typed_updates.insert(k.clone(), typed_v);
1336 }
1337 let typed = TypedExpr::new(
1338 result_ty.clone(),
1339 TypedExprKind::RecordUpdate {
1340 base: Box::new(typed_base),
1341 updates: typed_updates,
1342 },
1343 );
1344 Ok((preds, result_ty, typed))
1345 }
1346 Expr::Let(..) => {
1347 let mut bindings = Vec::new();
1348 let mut cur = expr;
1349 while let Expr::Let(_, v, ann, d, b) = cur {
1350 bindings.push((v.clone(), ann.clone(), d.clone()));
1351 cur = b.as_ref();
1352 }
1353
1354 let mut env_cur = env.clone();
1355 let mut known_cur = known.clone();
1356 let mut typed_defs = Vec::new();
1357 for (v, ann, d) in bindings {
1358 let (p1, t1, typed_def) = if let Some(ref ann_expr) = ann {
1359 let mut ann_vars = HashMap::new();
1360 let ann_ty = type_from_annotation_expr_vars(
1361 adts,
1362 ann_expr,
1363 &mut ann_vars,
1364 supply,
1365 )?;
1366 match d.as_ref() {
1367 Expr::RecordUpdate(_, base, updates) => {
1368 infer_record_update_typed_with_hint(
1369 unifier,
1370 supply,
1371 &env_cur,
1372 adts,
1373 &known_cur,
1374 base.as_ref(),
1375 updates,
1376 &ann_ty,
1377 )?
1378 }
1379 _ => {
1380 let (p1, t1, typed_def) = infer_expr(
1381 unifier, supply, &env_cur, adts, &known_cur, &d,
1382 )?;
1383 unifier.unify(&t1, &ann_ty)?;
1384 (p1, t1, typed_def)
1385 }
1386 }
1387 } else {
1388 infer_expr(unifier, supply, &env_cur, adts, &known_cur, &d)?
1389 };
1390 let def_ty = unifier.apply_type(&t1);
1391 let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
1392 monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
1393 } else {
1394 let scheme =
1395 generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
1396 reject_ambiguous_scheme(&scheme)?;
1397 scheme
1398 };
1399 env_cur.extend(v.name.clone(), scheme);
1400 if let Some(known_variant) =
1401 known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
1402 {
1403 known_cur.insert(
1404 v.name.clone(),
1405 KnownVariant {
1406 adt: known_variant.adt,
1407 variant: known_variant.variant,
1408 },
1409 );
1410 } else {
1411 known_cur.remove(&v.name);
1412 }
1413 typed_defs.push((v.name.clone(), typed_def));
1414 }
1415
1416 let (p_body, t_body, typed_body) =
1417 infer_expr(unifier, supply, &env_cur, adts, &known_cur, cur)?;
1418
1419 let mut typed = typed_body;
1420 for (name, def) in typed_defs.into_iter().rev() {
1421 typed = TypedExpr::new(
1422 t_body.clone(),
1423 TypedExprKind::Let {
1424 name,
1425 def: Box::new(def),
1426 body: Box::new(typed),
1427 },
1428 );
1429 }
1430 Ok((p_body, t_body, typed))
1431 }
1432 Expr::LetRec(_, bindings, body) => {
1433 let mut env_seed = env.clone();
1434 let mut known_seed = known.clone();
1435 let mut binding_tys = HashMap::new();
1436 for (var, _ann, _def) in bindings {
1437 let tv = Type::var(supply.fresh(Some(var.name.clone())));
1438 env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
1439 known_seed.remove(&var.name);
1440 binding_tys.insert(var.name.clone(), tv);
1441 }
1442
1443 let mut inferred_defs = Vec::with_capacity(bindings.len());
1444 for (var, ann, def) in bindings {
1445 let (preds, def_ty, typed_def) =
1446 infer_expr(unifier, supply, &env_seed, adts, &known_seed, def)?;
1447 if let Some(ann) = ann {
1448 let mut ann_vars = HashMap::new();
1449 let ann_ty =
1450 type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
1451 unifier.unify(&def_ty, &ann_ty)?;
1452 }
1453 let binding_ty = binding_tys
1454 .get(&var.name)
1455 .cloned()
1456 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1457 unifier.unify(&binding_ty, &def_ty)?;
1458 let resolved_ty = unifier.apply_type(&binding_ty);
1459
1460 if let Some(known_variant) =
1461 known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
1462 {
1463 known_seed.insert(
1464 var.name.clone(),
1465 KnownVariant {
1466 adt: known_variant.adt,
1467 variant: known_variant.variant,
1468 },
1469 );
1470 } else {
1471 known_seed.remove(&var.name);
1472 }
1473 inferred_defs.push((var.name.clone(), preds, resolved_ty, typed_def));
1474 }
1475
1476 let mut env_body = env.clone();
1477 let mut typed_bindings = Vec::with_capacity(inferred_defs.len());
1478 for (name, preds, def_ty, typed_def) in inferred_defs {
1479 let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
1480 reject_ambiguous_scheme(&scheme)?;
1481 env_body.extend(name.clone(), scheme);
1482 typed_bindings.push((name, typed_def));
1483 }
1484
1485 let (p_body, t_body, typed_body) =
1486 infer_expr(unifier, supply, &env_body, adts, &known_seed, body)?;
1487 let typed = TypedExpr::new(
1488 t_body.clone(),
1489 TypedExprKind::LetRec {
1490 bindings: typed_bindings,
1491 body: Box::new(typed_body),
1492 },
1493 );
1494 Ok((p_body, t_body, typed))
1495 }
1496 Expr::Ite(_, cond, then_expr, else_expr) => {
1497 let (p1, t1, typed_cond) = infer_expr(unifier, supply, env, adts, known, cond)?;
1498 unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
1499 let (p2, t2, typed_then) =
1500 infer_expr(unifier, supply, env, adts, known, then_expr)?;
1501 let (p3, t3, typed_else) =
1502 infer_expr(unifier, supply, env, adts, known, else_expr)?;
1503 unifier.unify(&t2, &t3)?;
1504 let out_ty = unifier.apply_type(&t2);
1505 let mut preds = p1;
1506 preds.extend(p2);
1507 preds.extend(p3);
1508 let typed = TypedExpr::new(
1509 out_ty.clone(),
1510 TypedExprKind::Ite {
1511 cond: Box::new(typed_cond),
1512 then_expr: Box::new(typed_then),
1513 else_expr: Box::new(typed_else),
1514 },
1515 );
1516 Ok((preds, out_ty, typed))
1517 }
1518 Expr::Tuple(_, elems) => {
1519 let mut preds = Vec::new();
1520 let mut types = Vec::new();
1521 let mut typed_elems = Vec::new();
1522 for elem in elems {
1523 let (p1, t1, typed_elem) =
1524 infer_expr(unifier, supply, env, adts, known, elem)?;
1525 preds.extend(p1);
1526 types.push(unifier.apply_type(&t1));
1527 typed_elems.push(typed_elem);
1528 }
1529 let tuple_ty = Type::tuple(types);
1530 let typed = TypedExpr::new(tuple_ty.clone(), TypedExprKind::Tuple(typed_elems));
1531 Ok((preds, tuple_ty, typed))
1532 }
1533 Expr::List(_, elems) => {
1534 let elem_tv = Type::var(supply.fresh(Some("a".into())));
1535 let mut preds = Vec::new();
1536 let mut typed_elems = Vec::new();
1537 for elem in elems {
1538 let (p1, t1, typed_elem) =
1539 infer_expr(unifier, supply, env, adts, known, elem)?;
1540 unifier.unify(&t1, &elem_tv)?;
1541 preds.extend(p1);
1542 typed_elems.push(typed_elem);
1543 }
1544 let list_ty = Type::app(
1545 Type::builtin(BuiltinTypeId::List),
1546 unifier.apply_type(&elem_tv),
1547 );
1548 let typed = TypedExpr::new(list_ty.clone(), TypedExprKind::List(typed_elems));
1549 Ok((preds, list_ty, typed))
1550 }
1551 Expr::Dict(_, kvs) => {
1552 let elem_tv = Type::var(supply.fresh(Some("v".into())));
1553 let mut preds = Vec::new();
1554 let mut typed_kvs = BTreeMap::new();
1555 for (k, v) in kvs {
1556 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v)?;
1557 unifier.unify(&t1, &elem_tv)?;
1558 preds.extend(p1);
1559 typed_kvs.insert(k.clone(), typed_v);
1560 }
1561 let dict_ty = Type::app(
1562 Type::builtin(BuiltinTypeId::Dict),
1563 unifier.apply_type(&elem_tv),
1564 );
1565 let typed = TypedExpr::new(dict_ty.clone(), TypedExprKind::Dict(typed_kvs));
1566 Ok((preds, dict_ty, typed))
1567 }
1568 Expr::Match(_, scrutinee, arms) => {
1569 let (p1, t1, typed_scrutinee) =
1570 infer_expr(unifier, supply, env, adts, known, scrutinee)?;
1571 let mut preds = p1;
1572 let mut typed_arms = Vec::new();
1573 let res_ty = Type::var(supply.fresh(Some("match".into())));
1574 let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
1575
1576 for (pat, expr) in arms {
1577 let scrutinee_ty = unifier.apply_type(&t1);
1578 let (p_pat, binds) =
1579 infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
1580 preds.extend(p_pat);
1581
1582 let mut env_arm = env.clone();
1583 for (name, ty) in binds {
1584 env_arm
1585 .extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
1586 }
1587 let mut known_arm = known.clone();
1588 if let Expr::Var(var) = scrutinee.as_ref() {
1589 match pat {
1590 Pattern::Named(_, name, _) => {
1591 let name_sym = name.to_dotted_symbol();
1592 if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
1593 known_arm.insert(
1594 var.name.clone(),
1595 KnownVariant {
1596 adt: adt.name.clone(),
1597 variant: name_sym,
1598 },
1599 );
1600 } else {
1601 known_arm.remove(&var.name);
1602 }
1603 }
1604 _ => {
1605 known_arm.remove(&var.name);
1606 }
1607 }
1608 }
1609 let (p_expr, t_expr, typed_expr) =
1610 infer_expr(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1611 unifier.unify(&res_ty, &t_expr)?;
1612 preds.extend(p_expr);
1613 typed_arms.push((pat.clone(), typed_expr));
1614 }
1615
1616 let scrutinee_ty = unifier.apply_type(&t1);
1617 check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1618 let out_ty = unifier.apply_type(&res_ty);
1619 let typed = TypedExpr::new(
1620 out_ty.clone(),
1621 TypedExprKind::Match {
1622 scrutinee: Box::new(typed_scrutinee),
1623 arms: typed_arms,
1624 },
1625 );
1626 Ok((preds, out_ty, typed))
1627 }
1628 Expr::Ann(_, expr, ann) => {
1629 let ann_ty = type_from_annotation_expr(adts, ann)?;
1630 match expr.as_ref() {
1631 Expr::RecordUpdate(_, base, updates) => {
1632 infer_record_update_typed_with_hint(
1633 unifier,
1634 supply,
1635 env,
1636 adts,
1637 known,
1638 base.as_ref(),
1639 updates,
1640 &ann_ty,
1641 )
1642 }
1643 _ => {
1644 let (preds, expr_ty, typed_expr) =
1645 infer_expr(unifier, supply, env, adts, known, expr)?;
1646 unifier.unify(&expr_ty, &ann_ty)?;
1647 let out_ty = unifier.apply_type(&ann_ty);
1648 Ok((preds, out_ty, typed_expr))
1649 }
1650 }
1651 }
1652 }
1653 })()
1654 });
1655 res.map_err(|err| with_span(&span, err))
1656}
1657
1658fn ctor_lookup<'a>(
1659 adts: &'a HashMap<Symbol, AdtDecl>,
1660 name: &Symbol,
1661) -> Option<(&'a AdtDecl, &'a AdtVariant)> {
1662 let mut found = None;
1663 for adt in adts.values() {
1664 if let Some(variant) = adt.variants.iter().find(|v| &v.name == name) {
1665 if found.is_some() {
1666 return None;
1667 }
1668 found = Some((adt, variant));
1669 }
1670 }
1671 found
1672}
1673
1674fn record_fields(variant: &AdtVariant) -> Option<&[(Symbol, Type)]> {
1675 if variant.args.len() != 1 {
1676 return None;
1677 }
1678 match variant.args[0].as_ref() {
1679 TypeKind::Record(fields) => Some(fields),
1680 _ => None,
1681 }
1682}
1683
1684fn instantiate_variant_fields(
1685 adt: &AdtDecl,
1686 variant: &AdtVariant,
1687 supply: &mut TypeVarSupply,
1688) -> Option<(Type, Vec<(Symbol, Type)>)> {
1689 let fields = record_fields(variant)?;
1690 let mut subst = Subst::new_sync();
1691 for param in &adt.params {
1692 let fresh = Type::var(supply.fresh(param.var.name.clone()));
1693 subst = subst.insert(param.var.id, fresh);
1694 }
1695 let result_ty = adt.result_type().apply(&subst);
1696 let fields = fields
1697 .iter()
1698 .map(|(name, ty)| (name.clone(), ty.apply(&subst)))
1699 .collect();
1700 Some((result_ty, fields))
1701}
1702
1703fn known_variant_from_expr(
1704 expr: &Expr,
1705 expr_ty: &Type,
1706 adts: &HashMap<Symbol, AdtDecl>,
1707) -> Option<KnownVariant> {
1708 let mut expr = expr;
1709 while let Expr::Ann(_, inner, _) = expr {
1710 expr = inner.as_ref();
1711 }
1712 if matches!(expr_ty.as_ref(), TypeKind::Fun(..)) {
1713 return None;
1714 }
1715 let ctor = match expr {
1716 Expr::App(_, f, _) => match f.as_ref() {
1717 Expr::Var(var) => var.name.clone(),
1718 _ => return None,
1719 },
1720 _ => return None,
1721 };
1722 let (adt, variant) = ctor_lookup(adts, &ctor)?;
1723 record_fields(variant)?;
1724 Some(KnownVariant {
1725 adt: adt.name.clone(),
1726 variant: variant.name.clone(),
1727 })
1728}
1729
1730fn known_variant_from_expr_with_known(
1731 expr: &Expr,
1732 expr_ty: &Type,
1733 adts: &HashMap<Symbol, AdtDecl>,
1734 known: &KnownVariants,
1735) -> Option<KnownVariant> {
1736 let mut expr = expr;
1737 while let Expr::Ann(_, inner, _) = expr {
1738 expr = inner.as_ref();
1739 }
1740 match expr {
1741 Expr::Var(var) => known.get(&var.name).cloned(),
1742 Expr::RecordUpdate(_, base, _) => {
1743 known_variant_from_expr_with_known(base.as_ref(), expr_ty, adts, known)
1744 }
1745 _ => known_variant_from_expr(expr, expr_ty, adts),
1746 }
1747}
1748
1749fn select_record_variant<'a, F>(
1750 adts: &'a HashMap<Symbol, AdtDecl>,
1751 base_ty: &Type,
1752 known_variant: Option<KnownVariant>,
1753 field_for_errors: &Symbol,
1754 matches_fields: F,
1755) -> Result<(&'a AdtDecl, &'a AdtVariant), TypeError>
1756where
1757 F: Fn(&[(Symbol, Type)]) -> bool,
1758{
1759 if let Some(info) = known_variant {
1760 let adt = adts
1761 .get(&info.adt)
1762 .ok_or_else(|| TypeError::UnknownTypeName(info.adt.clone()))?;
1763 let variant = adt
1764 .variants
1765 .iter()
1766 .find(|v| v.name == info.variant)
1767 .ok_or_else(|| TypeError::UnknownField {
1768 field: field_for_errors.clone(),
1769 typ: base_ty.to_string(),
1770 })?;
1771 return Ok((adt, variant));
1772 }
1773
1774 if let Some(adt_name) = type_head_name(base_ty) {
1775 let adt = adts.get(adt_name).ok_or_else(|| TypeError::UnknownField {
1776 field: field_for_errors.clone(),
1777 typ: base_ty.to_string(),
1778 })?;
1779 if adt.variants.len() == 1 {
1780 return Ok((adt, &adt.variants[0]));
1781 }
1782 return Err(TypeError::FieldNotKnown {
1783 field: field_for_errors.clone(),
1784 typ: base_ty.to_string(),
1785 });
1786 }
1787
1788 if matches!(base_ty.as_ref(), TypeKind::Var(_)) {
1789 let mut candidates = Vec::new();
1790 for adt in adts.values() {
1791 if adt.variants.len() != 1 {
1792 continue;
1793 }
1794 let variant = &adt.variants[0];
1795 let Some(fields) = record_fields(variant) else {
1796 continue;
1797 };
1798 if matches_fields(fields) {
1799 candidates.push((adt, variant));
1800 }
1801 }
1802 if candidates.len() == 1 {
1803 return Ok(candidates.remove(0));
1804 }
1805 if candidates.is_empty() {
1806 return Err(TypeError::UnknownField {
1807 field: field_for_errors.clone(),
1808 typ: base_ty.to_string(),
1809 });
1810 }
1811 return Err(TypeError::FieldNotKnown {
1812 field: field_for_errors.clone(),
1813 typ: base_ty.to_string(),
1814 });
1815 }
1816
1817 Err(TypeError::UnknownField {
1818 field: field_for_errors.clone(),
1819 typ: base_ty.to_string(),
1820 })
1821}
1822
1823fn resolve_record_update(
1824 unifier: &mut Unifier<'_>,
1825 supply: &mut TypeVarSupply,
1826 adts: &HashMap<Symbol, AdtDecl>,
1827 base_ty: &Type,
1828 known_variant: Option<KnownVariant>,
1829 update_fields: &[Symbol],
1830) -> Result<(Type, Vec<(Symbol, Type)>), TypeError> {
1831 if let TypeKind::Record(fields) = base_ty.as_ref() {
1832 return Ok((base_ty.clone(), fields.clone()));
1833 }
1834
1835 let field_for_errors = update_fields.first().cloned().unwrap_or_else(|| sym("_"));
1836
1837 let (adt, variant) =
1838 select_record_variant(adts, base_ty, known_variant, &field_for_errors, |fields| {
1839 update_fields
1840 .iter()
1841 .all(|field| fields.iter().any(|(name, _)| name == field))
1842 })?;
1843
1844 let (result_ty, fields) =
1845 instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1846 TypeError::UnknownField {
1847 field: field_for_errors.clone(),
1848 typ: base_ty.to_string(),
1849 }
1850 })?;
1851
1852 for field in update_fields {
1853 if fields.iter().all(|(name, _)| name != field) {
1854 return Err(TypeError::UnknownField {
1855 field: field.clone(),
1856 typ: base_ty.to_string(),
1857 });
1858 }
1859 }
1860
1861 unifier.unify(base_ty, &result_ty)?;
1862 let result_ty = unifier.apply_type(&result_ty);
1863 let fields = fields
1864 .into_iter()
1865 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
1866 .collect();
1867 Ok((result_ty, fields))
1868}
1869
1870fn resolve_projection(
1871 unifier: &mut Unifier<'_>,
1872 supply: &mut TypeVarSupply,
1873 adts: &HashMap<Symbol, AdtDecl>,
1874 base_ty: &Type,
1875 known_variant: Option<KnownVariant>,
1876 field: &Symbol,
1877) -> Result<Type, TypeError> {
1878 if let Ok(index) = field.as_ref().parse::<usize>() {
1879 let elem_ty = match base_ty.as_ref() {
1880 TypeKind::Tuple(elems) => {
1881 elems
1882 .get(index)
1883 .cloned()
1884 .ok_or_else(|| TypeError::UnknownField {
1885 field: field.clone(),
1886 typ: base_ty.to_string(),
1887 })?
1888 }
1889 TypeKind::Var(_) => {
1890 let mut elems = Vec::with_capacity(index + 1);
1891 for _ in 0..=index {
1892 elems.push(Type::var(supply.fresh(Some(sym("t")))));
1893 }
1894 let tuple_ty = Type::tuple(elems.clone());
1895 unifier.unify(base_ty, &tuple_ty)?;
1896 elems[index].clone()
1897 }
1898 _ => {
1899 return Err(TypeError::UnknownField {
1900 field: field.clone(),
1901 typ: base_ty.to_string(),
1902 });
1903 }
1904 };
1905 return Ok(unifier.apply_type(&elem_ty));
1906 }
1907
1908 let (adt, variant) = select_record_variant(adts, base_ty, known_variant, field, |fields| {
1909 fields.iter().any(|(name, _)| name == field)
1910 })?;
1911
1912 let (result_ty, fields) =
1913 instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1914 TypeError::UnknownField {
1915 field: field.clone(),
1916 typ: base_ty.to_string(),
1917 }
1918 })?;
1919 let field_ty = fields
1920 .iter()
1921 .find(|(name, _)| name == field)
1922 .map(|(_, ty)| ty.clone())
1923 .ok_or_else(|| TypeError::UnknownField {
1924 field: field.clone(),
1925 typ: base_ty.to_string(),
1926 })?;
1927 unifier.unify(base_ty, &result_ty)?;
1928 Ok(unifier.apply_type(&field_ty))
1929}
1930
1931fn decompose_fun(typ: &Type, arity: usize) -> Option<(Vec<Type>, Type)> {
1932 let mut args = Vec::with_capacity(arity);
1933 let mut cur = typ.clone();
1934 for _ in 0..arity {
1935 match cur.as_ref() {
1936 TypeKind::Fun(a, b) => {
1937 args.push(a.clone());
1938 cur = b.clone();
1939 }
1940 _ => return None,
1941 }
1942 }
1943 Some((args, cur))
1944}
1945
1946type InferPatternResult = (Vec<Predicate>, Vec<(Symbol, Type)>);
1947
1948fn infer_pattern(
1949 unifier: &mut Unifier<'_>,
1950 supply: &mut TypeVarSupply,
1951 env: &TypeEnv,
1952 pat: &Pattern,
1953 scrutinee_ty: &Type,
1954) -> Result<InferPatternResult, TypeError> {
1955 let span = *pat.span();
1956 let res = (|| {
1957 unifier.charge_infer_node()?;
1958 match pat {
1959 Pattern::Wildcard(..) => Ok((vec![], vec![])),
1960 Pattern::Var(var) => Ok((
1961 vec![],
1962 vec![(var.name.clone(), unifier.apply_type(scrutinee_ty))],
1963 )),
1964 Pattern::Named(_, name, ps) => {
1965 let ctor_name = name.to_dotted_symbol();
1966 let schemes = env
1967 .lookup(&ctor_name)
1968 .ok_or_else(|| TypeError::UnknownVar(ctor_name.clone()))?;
1969 if schemes.len() != 1 {
1970 return Err(TypeError::AmbiguousOverload(ctor_name));
1971 }
1972 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
1973 let (preds, ctor_ty) = instantiate(&scheme, supply);
1974 let (arg_tys, res_ty) = decompose_fun(&ctor_ty, ps.len())
1975 .ok_or(TypeError::UnsupportedExpr("pattern constructor"))?;
1976 unifier.unify(&res_ty, scrutinee_ty)?;
1977 let mut all_preds = preds;
1978 let mut bindings = Vec::new();
1979 for (p, arg_ty) in ps.iter().zip(arg_tys.iter()) {
1980 let arg_ty = unifier.apply_type(arg_ty);
1981 let (p1, binds1) = infer_pattern(unifier, supply, env, p, &arg_ty)?;
1982 all_preds.extend(p1);
1983 bindings.extend(binds1);
1984 }
1985 let bindings = bindings
1986 .into_iter()
1987 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
1988 .collect();
1989 Ok((all_preds, bindings))
1990 }
1991 Pattern::List(_, ps) => {
1992 let elem_tv = Type::var(supply.fresh(Some("a".into())));
1993 let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
1994 unifier.unify(scrutinee_ty, &list_ty)?;
1995 let mut preds = Vec::new();
1996 let mut bindings = Vec::new();
1997 for p in ps {
1998 let elem_ty = unifier.apply_type(&elem_tv);
1999 let (p1, binds1) = infer_pattern(unifier, supply, env, p, &elem_ty)?;
2000 preds.extend(p1);
2001 bindings.extend(binds1);
2002 }
2003 let bindings = bindings
2004 .into_iter()
2005 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2006 .collect();
2007 Ok((preds, bindings))
2008 }
2009 Pattern::Cons(_, head, tail) => {
2010 let elem_tv = Type::var(supply.fresh(Some("a".into())));
2011 let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
2012 unifier.unify(scrutinee_ty, &list_ty)?;
2013 let mut preds = Vec::new();
2014 let mut bindings = Vec::new();
2015
2016 let head_ty = unifier.apply_type(&elem_tv);
2017 let (p1, binds1) = infer_pattern(unifier, supply, env, head, &head_ty)?;
2018 preds.extend(p1);
2019 bindings.extend(binds1);
2020
2021 let tail_ty = Type::app(
2022 Type::builtin(BuiltinTypeId::List),
2023 unifier.apply_type(&elem_tv),
2024 );
2025 let (p2, binds2) = infer_pattern(unifier, supply, env, tail, &tail_ty)?;
2026 preds.extend(p2);
2027 bindings.extend(binds2);
2028
2029 let bindings = bindings
2030 .into_iter()
2031 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2032 .collect();
2033 Ok((preds, bindings))
2034 }
2035 Pattern::Tuple(_, elems) => {
2036 let mut elem_tys: Vec<Type> = (0..elems.len())
2037 .map(|i| Type::var(supply.fresh(Some(format!("t{i}").into()))))
2038 .collect();
2039 let expected = Type::tuple(elem_tys.clone());
2040 unifier.unify(scrutinee_ty, &expected)?;
2041 elem_tys = elem_tys
2042 .into_iter()
2043 .map(|t| unifier.apply_type(&t))
2044 .collect();
2045
2046 let mut preds = Vec::new();
2047 let mut bindings = Vec::new();
2048 for (p, ty) in elems.iter().zip(elem_tys.iter()) {
2049 let (p_preds, p_binds) = infer_pattern(unifier, supply, env, p, ty)?;
2050 preds.extend(p_preds);
2051 bindings.extend(p_binds);
2052 }
2053 let bindings = bindings
2054 .into_iter()
2055 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2056 .collect();
2057 Ok((preds, bindings))
2058 }
2059 Pattern::Dict(_, fields) => {
2060 if let TypeKind::Record(ty_fields) = scrutinee_ty.as_ref() {
2061 let mut preds = Vec::new();
2062 let mut bindings = Vec::new();
2063 for (key, pat) in fields {
2064 let ty = ty_fields
2065 .iter()
2066 .find(|(name, _)| name == key)
2067 .map(|(_, ty)| unifier.apply_type(ty))
2068 .ok_or_else(|| TypeError::UnknownField {
2069 field: key.clone(),
2070 typ: scrutinee_ty.to_string(),
2071 })?;
2072 let (p_preds, p_binds) = infer_pattern(unifier, supply, env, pat, &ty)?;
2073 preds.extend(p_preds);
2074 bindings.extend(p_binds);
2075 }
2076 let bindings = bindings
2077 .into_iter()
2078 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2079 .collect();
2080 Ok((preds, bindings))
2081 } else {
2082 let elem_tv = Type::var(supply.fresh(Some("v".into())));
2083 let dict_ty = Type::app(Type::builtin(BuiltinTypeId::Dict), elem_tv.clone());
2084 unifier.unify(scrutinee_ty, &dict_ty)?;
2085 let elem_ty = unifier.apply_type(&elem_tv);
2086
2087 let mut preds = Vec::new();
2088 let mut bindings = Vec::new();
2089 for (_key, pat) in fields {
2090 let (p_preds, p_binds) =
2091 infer_pattern(unifier, supply, env, pat, &elem_ty)?;
2092 preds.extend(p_preds);
2093 bindings.extend(p_binds);
2094 }
2095 let bindings = bindings
2096 .into_iter()
2097 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
2098 .collect();
2099 Ok((preds, bindings))
2100 }
2101 }
2102 }
2103 })();
2104 res.map_err(|err| with_span(&span, err))
2105}
2106
2107fn type_head_name(typ: &Type) -> Option<&Symbol> {
2108 let mut cur = typ;
2109 while let TypeKind::App(head, _) = cur.as_ref() {
2110 cur = head;
2111 }
2112 match cur.as_ref() {
2113 TypeKind::Con(tc) => Some(&tc.name),
2114 _ => None,
2115 }
2116}
2117
2118fn adt_name_from_patterns(adts: &HashMap<Symbol, AdtDecl>, patterns: &[Pattern]) -> Option<Symbol> {
2119 let mut candidate: Option<Symbol> = None;
2120 for pat in patterns {
2121 let next = match pat {
2122 Pattern::Named(_, name, _) => {
2123 let name_sym = name.to_dotted_symbol();
2124 ctor_lookup(adts, &name_sym).map(|(adt, _)| adt.name.clone())
2125 }
2126 Pattern::List(..) | Pattern::Cons(..) => Some(sym("List")),
2127 _ => None,
2128 };
2129 if let Some(next) = next {
2130 match &candidate {
2131 None => candidate = Some(next),
2132 Some(prev) if *prev == next => {}
2133 Some(_) => return None,
2134 }
2135 }
2136 }
2137 candidate
2138}
2139
2140fn check_match_exhaustive(
2141 adts: &HashMap<Symbol, AdtDecl>,
2142 scrutinee_ty: &Type,
2143 patterns: &[Pattern],
2144) -> Result<(), TypeError> {
2145 if patterns
2146 .iter()
2147 .any(|p| matches!(p, Pattern::Wildcard(..) | Pattern::Var(_)))
2148 {
2149 return Ok(());
2150 }
2151 let adt_name = match type_head_name(scrutinee_ty).cloned() {
2152 Some(name) => name,
2153 None => match adt_name_from_patterns(adts, patterns) {
2154 Some(name) => name,
2155 None => return Ok(()),
2156 },
2157 };
2158 let adt = match adts.get(&adt_name) {
2159 Some(adt) => adt,
2160 None => return Ok(()),
2161 };
2162 let ctor_names: HashSet<Symbol> = adt.variants.iter().map(|v| v.name.clone()).collect();
2163 if ctor_names.is_empty() {
2164 return Ok(());
2165 }
2166 let mut covered = HashSet::new();
2167 for pat in patterns {
2168 match pat {
2169 Pattern::Named(_, name, _) => {
2170 let name_sym = name.to_dotted_symbol();
2171 if ctor_names.contains(&name_sym) {
2172 covered.insert(name_sym);
2173 }
2174 }
2175 Pattern::List(_, elems) if adt_name.as_ref() == "List" && elems.is_empty() => {
2176 covered.insert(sym("Empty"));
2177 }
2178 Pattern::Cons(..) if adt_name.as_ref() == "List" => {
2179 covered.insert(sym("Cons"));
2180 }
2181 _ => {}
2182 }
2183 }
2184 let mut missing: Vec<Symbol> = ctor_names.difference(&covered).cloned().collect();
2185 if missing.is_empty() {
2186 return Ok(());
2187 }
2188 missing.sort();
2189 Err(TypeError::NonExhaustiveMatch {
2190 typ: scrutinee_ty.to_string(),
2191 missing,
2192 })
2193}
2194
2195#[cfg(test)]
2196mod tests {
2197 use super::*;
2198 use rexlang_lexer::Token;
2199 use rexlang_parser::Parser;
2200 use rexlang_util::{GasCosts, GasMeter};
2201
2202 fn tvar(id: TypeVarId, name: &str) -> Type {
2203 Type::var(TypeVar::new(id, Some(sym(name))))
2204 }
2205
2206 fn dict_of(elem: Type) -> Type {
2207 Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
2208 }
2209
2210 #[test]
2211 fn unify_simple() {
2212 let t1 = Type::fun(tvar(0, "a"), Type::builtin(BuiltinTypeId::U32));
2213 let t2 = Type::fun(Type::builtin(BuiltinTypeId::U16), tvar(1, "b"));
2214 let subst = unify(&t1, &t2).unwrap();
2215 assert_eq!(subst.get(&0), Some(&Type::builtin(BuiltinTypeId::U16)));
2216 assert_eq!(subst.get(&1), Some(&Type::builtin(BuiltinTypeId::U32)));
2217 }
2218
2219 #[test]
2220 fn occurs_check_blocks_infinite_type() {
2221 let tv = TypeVar::new(0, Some(sym("a")));
2222 let t = Type::fun(Type::var(tv.clone()), Type::builtin(BuiltinTypeId::U8));
2223 let err = bind(&tv, &t).unwrap_err();
2224 assert!(matches!(err, TypeError::Occurs(_, _)));
2225 }
2226
2227 #[test]
2228 fn instantiate_and_generalize_round_trip() {
2229 let mut supply = TypeVarSupply::new();
2230 let a = Type::var(supply.fresh(Some(sym("a"))));
2231 let scheme = generalize(&TypeEnv::new(), vec![], Type::fun(a.clone(), a.clone()));
2232 let (preds, inst) = instantiate(&scheme, &mut supply);
2233 assert!(preds.is_empty());
2234 if let TypeKind::Fun(l, r) = inst.as_ref() {
2235 match (l.as_ref(), r.as_ref()) {
2236 (TypeKind::Var(_), TypeKind::Var(_)) => {}
2237 _ => panic!("expected polymorphic identity"),
2238 }
2239 } else {
2240 panic!("expected function type");
2241 }
2242 }
2243
2244 #[test]
2245 fn entail_superclasses() {
2246 let ts = TypeSystem::new_with_prelude().unwrap();
2247 let pred = Predicate::new("Semiring", Type::builtin(BuiltinTypeId::I32));
2248 let given = [Predicate::new(
2249 "AdditiveGroup",
2250 Type::builtin(BuiltinTypeId::I32),
2251 )];
2252 assert!(entails(&ts.classes, &given, &pred).unwrap());
2253 }
2254
2255 #[test]
2256 fn entail_instances() {
2257 let ts = TypeSystem::new_with_prelude().unwrap();
2258 let pred = Predicate::new("Field", Type::builtin(BuiltinTypeId::F32));
2259 assert!(entails(&ts.classes, &[], &pred).unwrap());
2260
2261 let pred_fail = Predicate::new("Field", Type::builtin(BuiltinTypeId::U32));
2262 assert!(!entails(&ts.classes, &[], &pred_fail).unwrap());
2263 }
2264
2265 #[test]
2266 fn prelude_injects_functions() {
2267 let ts = TypeSystem::new_with_prelude().unwrap();
2268 let minus = ts.env.lookup(&sym("-")).expect("minus in env");
2269 let div = ts.env.lookup(&sym("/")).expect("div in env");
2270 assert_eq!(minus.len(), 1);
2271 assert_eq!(div.len(), 1);
2272 let minus = &minus[0];
2273 let div = &div[0];
2274 assert_eq!(minus.preds.len(), 1);
2275 assert_eq!(minus.vars.len(), 1);
2276 assert_eq!(div.preds.len(), 1);
2277 assert_eq!(div.vars.len(), 1);
2278 }
2279
2280 #[test]
2281 fn adt_constructors_are_present() {
2282 let ts = TypeSystem::new_with_prelude().unwrap();
2283 assert!(ts.env.lookup(&sym("Empty")).is_some());
2284 assert!(ts.env.lookup(&sym("Cons")).is_some());
2285 assert!(ts.env.lookup(&sym("Ok")).is_some());
2286 assert!(ts.env.lookup(&sym("Err")).is_some());
2287 assert!(ts.env.lookup(&sym("Some")).is_some());
2288 assert!(ts.env.lookup(&sym("None")).is_some());
2289 }
2290
2291 fn parse_expr(code: &str) -> std::sync::Arc<rexlang_ast::expr::Expr> {
2292 let mut parser = Parser::new(Token::tokenize(code).unwrap());
2293 parser.parse_program(&mut GasMeter::default()).unwrap().expr
2294 }
2295
2296 fn parse_program(code: &str) -> rexlang_ast::expr::Program {
2297 let mut parser = Parser::new(Token::tokenize(code).unwrap());
2298 parser.parse_program(&mut GasMeter::default()).unwrap()
2299 }
2300
2301 #[test]
2302 fn infer_deep_list_does_not_overflow() {
2303 const N: usize = 40;
2304 let mut code = String::new();
2305 code.push_str("let xs = ");
2306 for _ in 0..N {
2307 code.push_str("Cons 0 (");
2308 }
2309 code.push_str("Empty");
2310 for _ in 0..N {
2311 code.push(')');
2312 }
2313 code.push_str(" in xs");
2314
2315 let parse_handle = std::thread::Builder::new()
2316 .name("infer_deep_list_parse".into())
2317 .stack_size(128 * 1024 * 1024)
2318 .spawn(move || {
2319 let tokens = Token::tokenize(&code).unwrap();
2320 let mut parser = Parser::new(tokens);
2321 parser.parse_program(&mut GasMeter::default())
2322 })
2323 .unwrap();
2324 let program = parse_handle.join().unwrap().unwrap();
2325 let expr = program.expr;
2326 let mut ts = TypeSystem::new_with_prelude().unwrap();
2327 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2328 assert_eq!(
2329 ty,
2330 Type::app(
2331 Type::builtin(BuiltinTypeId::List),
2332 Type::builtin(BuiltinTypeId::I32)
2333 )
2334 );
2335 }
2336
2337 #[test]
2338 fn collect_adts_in_types_finds_nested_unique_adts() {
2339 let foo = Type::user_con("Foo", 1);
2340 let bar = Type::user_con("Bar", 0);
2341 let ty = Type::fun(
2342 Type::app(
2343 Type::builtin(BuiltinTypeId::List),
2344 Type::app(foo.clone(), tvar(0, "a")),
2345 ),
2346 Type::tuple(vec![
2347 Type::app(foo.clone(), Type::builtin(BuiltinTypeId::I32)),
2348 bar.clone(),
2349 ]),
2350 );
2351
2352 let adts = collect_adts_in_types(vec![ty]).unwrap();
2353 assert_eq!(adts, vec![foo, bar]);
2354 }
2355
2356 #[test]
2357 fn collect_adts_in_types_rejects_conflicting_names() {
2358 let arity1 = Type::user_con("Thing", 1);
2359 let arity2 = Type::user_con("Thing", 2);
2360
2361 let err = collect_adts_in_types(vec![arity1.clone(), arity2.clone()]).unwrap_err();
2362 assert_eq!(err.conflicts.len(), 1);
2363 let conflict = &err.conflicts[0];
2364 assert_eq!(conflict.name, sym("Thing"));
2365 assert_eq!(conflict.definitions, vec![arity1, arity2]);
2366 }
2367
2368 #[test]
2369 fn infer_depth_limit_is_enforced() {
2370 const N: usize = 40;
2371 let mut code = String::new();
2372 code.push_str("let xs = ");
2373 for _ in 0..N {
2374 code.push_str("Cons 0 (");
2375 }
2376 code.push_str("Empty");
2377 for _ in 0..N {
2378 code.push(')');
2379 }
2380 code.push_str(" in xs");
2381
2382 let program = parse_program(&code);
2383 let mut ts = TypeSystem::new_with_prelude().unwrap();
2384 ts.set_limits(TypeSystemLimits {
2385 max_infer_depth: Some(8),
2386 });
2387
2388 let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
2389 assert!(
2390 err.to_string().contains("maximum inference depth exceeded"),
2391 "expected a max-depth inference error, got: {err:?}"
2392 );
2393 }
2394
2395 #[test]
2396 fn declare_fn_injects_scheme_for_use_sites() {
2397 let program = parse_program(
2398 r#"
2399 declare fn id x: a -> a
2400 id 1
2401 "#,
2402 );
2403 let mut ts = TypeSystem::new_with_prelude().unwrap();
2404 ts.register_decls(&program.decls).unwrap();
2405 let (preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2406 assert!(
2407 preds.is_empty()
2408 || preds.iter().all(|p| p.class.as_ref() == "Integral"
2409 && p.typ == Type::builtin(BuiltinTypeId::I32))
2410 );
2411 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2412 }
2413
2414 #[test]
2415 fn declare_fn_is_noop_when_matching_existing_scheme() {
2416 let mut ts = TypeSystem::new_with_prelude().unwrap();
2417 ts.add_value(
2418 "foo",
2419 Scheme::new(
2420 vec![],
2421 vec![],
2422 Type::fun(
2423 Type::builtin(BuiltinTypeId::I32),
2424 Type::builtin(BuiltinTypeId::I32),
2425 ),
2426 ),
2427 );
2428
2429 let program = parse_program(
2430 r#"
2431 declare fn foo x: i32 -> i32
2432 0
2433 "#,
2434 );
2435 let rexlang_ast::expr::Decl::DeclareFn(fd) = &program.decls[0] else {
2436 panic!("expected declare fn decl");
2437 };
2438 ts.inject_declare_fn_decl(fd).unwrap();
2439 }
2440
2441 #[test]
2442 fn unit_type_parses_and_infers() {
2443 let program = parse_program(
2444 r#"
2445 fn unit_id x: () -> () = x
2446 unit_id ()
2447 "#,
2448 );
2449 let mut ts = TypeSystem::new_with_prelude().unwrap();
2450 ts.register_decls(&program.decls).unwrap();
2451 let (preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2452 assert!(preds.is_empty());
2453 assert_eq!(ty, Type::tuple(vec![]));
2454 }
2455
2456 fn strip_span(mut err: TypeError) -> TypeError {
2457 while let TypeError::Spanned { error, .. } = err {
2458 err = *error;
2459 }
2460 err
2461 }
2462
2463 #[test]
2464 fn type_errors_include_span() {
2465 let expr = parse_expr("missing");
2466 let mut ts = TypeSystem::new_with_prelude().unwrap();
2467 let err = infer(&mut ts, expr.as_ref()).unwrap_err();
2468 match err {
2469 TypeError::Spanned { span, error } => {
2470 assert_ne!(span, Span::default());
2471 assert!(matches!(
2472 *error,
2473 TypeError::UnknownVar(name) if name.as_ref() == "missing"
2474 ));
2475 }
2476 other => panic!("expected spanned error, got {other:?}"),
2477 }
2478 }
2479
2480 #[test]
2481 fn infer_with_gas_rejects_out_of_budget() {
2482 let expr = parse_expr("1");
2483 let mut ts = TypeSystem::new_with_prelude().unwrap();
2484 let mut gas = GasMeter::new(
2485 Some(0),
2486 GasCosts {
2487 infer_node: 1,
2488 unify_step: 0,
2489 ..GasCosts::sensible_defaults()
2490 },
2491 );
2492 let err = infer_with_gas(&mut ts, expr.as_ref(), &mut gas).unwrap_err();
2493 assert!(matches!(strip_span(err), TypeError::OutOfGas(..)));
2494 }
2495
2496 #[test]
2497 fn reject_user_redefinition_of_primitive_type_name() {
2498 let program = parse_program("type i32 = I32Wrap i32");
2499 let mut ts = TypeSystem::new_with_prelude().unwrap();
2500 let rexlang_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2501 panic!("expected type decl");
2502 };
2503 let err = ts.register_type_decl(decl).unwrap_err();
2504 assert!(matches!(
2505 err,
2506 TypeError::ReservedTypeName(name) if name.as_ref() == "i32"
2507 ));
2508 }
2509
2510 #[test]
2511 fn reject_user_redefinition_of_prelude_adt_name() {
2512 let program = parse_program("type Result e a = Nope e a");
2513 let mut ts = TypeSystem::new_with_prelude().unwrap();
2514 let rexlang_ast::expr::Decl::Type(decl) = &program.decls[0] else {
2515 panic!("expected type decl");
2516 };
2517 let err = ts.register_type_decl(decl).unwrap_err();
2518 assert!(matches!(
2519 err,
2520 TypeError::ReservedTypeName(name) if name.as_ref() == "Result"
2521 ));
2522 }
2523
2524 #[test]
2525 fn infer_polymorphic_id_tuple() {
2526 let expr = parse_expr(
2527 r#"
2528 let
2529 id = \x -> x
2530 in
2531 id (id 420, id 6.9, id "str")
2532 "#,
2533 );
2534 let mut ts = TypeSystem::new_with_prelude().unwrap();
2535 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2536 let expected = Type::tuple(vec![
2537 Type::builtin(BuiltinTypeId::I32),
2538 Type::builtin(BuiltinTypeId::F32),
2539 Type::builtin(BuiltinTypeId::String),
2540 ]);
2541 assert_eq!(ty, expected);
2542 }
2543
2544 #[test]
2545 fn infer_type_annotation_ok() {
2546 let expr = parse_expr("let x: i32 = 42 in x");
2547 let mut ts = TypeSystem::new_with_prelude().unwrap();
2548 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2549 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2550 }
2551
2552 #[test]
2553 fn infer_type_annotation_lambda_param() {
2554 let expr = parse_expr("\\ (a : f32) -> a");
2555 let mut ts = TypeSystem::new_with_prelude().unwrap();
2556 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2557 assert_eq!(
2558 ty,
2559 Type::fun(
2560 Type::builtin(BuiltinTypeId::F32),
2561 Type::builtin(BuiltinTypeId::F32)
2562 )
2563 );
2564 }
2565
2566 #[test]
2567 fn infer_type_annotation_is_alias() {
2568 let expr = parse_expr("\"hi\" is str");
2569 let mut ts = TypeSystem::new_with_prelude().unwrap();
2570 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2571 assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
2572 }
2573
2574 #[test]
2575 fn infer_type_annotation_mismatch_error() {
2576 let expr = parse_expr("let x: i32 = 3.14 in x");
2577 let mut ts = TypeSystem::new_with_prelude().unwrap();
2578 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
2579 assert!(matches!(err, TypeError::Unification(_, _)));
2580 }
2581
2582 #[test]
2583 fn infer_project_single_variant_let() {
2584 let program = parse_program(
2585 r#"
2586 type MyADT = MyVariant1 { field1: i32, field2: f32 }
2587 let
2588 x = MyVariant1 { field1 = 1, field2 = 2.0 }
2589 in
2590 (x.field1, x.field2)
2591 "#,
2592 );
2593 let mut ts = TypeSystem::new_with_prelude().unwrap();
2594 for decl in &program.decls {
2595 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2596 ts.register_type_decl(decl).unwrap();
2597 }
2598 }
2599 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2600 let expected = Type::tuple(vec![
2601 Type::builtin(BuiltinTypeId::I32),
2602 Type::builtin(BuiltinTypeId::F32),
2603 ]);
2604 assert_eq!(ty, expected);
2605 }
2606
2607 #[test]
2608 fn infer_project_known_variant_let() {
2609 let program = parse_program(
2610 r#"
2611 type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
2612 let
2613 x = MyVariant1 { field1 = 1, field2 = 2.0 }
2614 in
2615 x.field1
2616 "#,
2617 );
2618 let mut ts = TypeSystem::new_with_prelude().unwrap();
2619 for decl in &program.decls {
2620 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2621 ts.register_type_decl(decl).unwrap();
2622 }
2623 }
2624 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2625 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2626 }
2627
2628 #[test]
2629 fn infer_project_unknown_variant_error() {
2630 let program = parse_program(
2631 r#"
2632 type MyADT = MyVariant1 { field1: i32, field2: f32 } | MyVariant2 i32 f32
2633 let
2634 x = MyVariant2 1 2.0
2635 in
2636 x.field1
2637 "#,
2638 );
2639 let mut ts = TypeSystem::new_with_prelude().unwrap();
2640 for decl in &program.decls {
2641 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2642 ts.register_type_decl(decl).unwrap();
2643 }
2644 }
2645 let err = strip_span(infer(&mut ts, program.expr.as_ref()).unwrap_err());
2646 assert!(matches!(err, TypeError::FieldNotKnown { .. }));
2647 }
2648
2649 #[test]
2650 fn infer_project_lambda_param_single_variant() {
2651 let program = parse_program(
2652 r#"
2653 type Boxed = Boxed { value: i32 }
2654 let
2655 f = \x -> x.value
2656 in
2657 f (Boxed { value = 1 })
2658 "#,
2659 );
2660 let mut ts = TypeSystem::new_with_prelude().unwrap();
2661 for decl in &program.decls {
2662 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2663 ts.register_type_decl(decl).unwrap();
2664 }
2665 }
2666 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2667 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2668 }
2669
2670 #[test]
2671 fn infer_project_in_match_arm() {
2672 let program = parse_program(
2673 r#"
2674 type MyADT = MyVariant1 { field1: i32 } | MyVariant2 i32
2675 let
2676 x = MyVariant1 { field1 = 1 }
2677 in
2678 match x
2679 when MyVariant1 { field1 } -> x.field1
2680 when MyVariant2 _ -> 0
2681 "#,
2682 );
2683 let mut ts = TypeSystem::new_with_prelude().unwrap();
2684 for decl in &program.decls {
2685 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2686 ts.register_type_decl(decl).unwrap();
2687 }
2688 }
2689 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2690 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2691 }
2692
2693 #[test]
2694 fn infer_nested_let_lambda_match_option() {
2695 let expr = parse_expr(
2696 r#"
2697 let
2698 choose = \flag a b -> if flag then a else b,
2699 build = \flag ->
2700 let
2701 pick = choose flag,
2702 val = pick 1 2
2703 in
2704 Some val
2705 in
2706 match (build true)
2707 when Some x -> x
2708 when None -> 0
2709 "#,
2710 );
2711 let mut ts = TypeSystem::new_with_prelude().unwrap();
2712 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2713 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2714 }
2715
2716 #[test]
2717 fn infer_polymorphic_apply_in_tuple() {
2718 let expr = parse_expr(
2719 r#"
2720 let
2721 apply = \f x -> f x,
2722 id = \x -> x,
2723 wrap = \x -> (x, x)
2724 in
2725 (apply id 1, apply id "hi", apply wrap true)
2726 "#,
2727 );
2728 let mut ts = TypeSystem::new_with_prelude().unwrap();
2729 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2730 let expected = Type::tuple(vec![
2731 Type::builtin(BuiltinTypeId::I32),
2732 Type::builtin(BuiltinTypeId::String),
2733 Type::tuple(vec![
2734 Type::builtin(BuiltinTypeId::Bool),
2735 Type::builtin(BuiltinTypeId::Bool),
2736 ]),
2737 ]);
2738 assert_eq!(ty, expected);
2739 }
2740
2741 #[test]
2742 fn infer_nested_result_option_match() {
2743 let expr = parse_expr(
2744 r#"
2745 let
2746 unwrap = \x ->
2747 match x
2748 when Ok (Some v) -> v
2749 when Ok None -> 0
2750 when Err _ -> 0
2751 in
2752 unwrap (Ok (Some 5))
2753 "#,
2754 );
2755 let mut ts = TypeSystem::new_with_prelude().unwrap();
2756 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2757 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2758 }
2759
2760 #[test]
2761 fn infer_head_or_list_match() {
2762 let expr = parse_expr(
2763 r#"
2764 let
2765 head_or = \fallback xs ->
2766 match xs
2767 when [] -> fallback
2768 when x::xs -> x
2769 in
2770 (head_or 0 [1, 2, 3], head_or 0 [])
2771 "#,
2772 );
2773 let mut ts = TypeSystem::new_with_prelude().unwrap();
2774 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2775 let expected = Type::tuple(vec![
2776 Type::builtin(BuiltinTypeId::I32),
2777 Type::builtin(BuiltinTypeId::I32),
2778 ]);
2779 assert_eq!(ty, expected);
2780 }
2781
2782 #[test]
2783 fn infer_head_or_list_match_cons_constructor_form() {
2784 let expr = parse_expr(
2785 r#"
2786 let
2787 head_or = \fallback xs ->
2788 match xs
2789 when [] -> fallback
2790 when Cons x xs1 -> x
2791 in
2792 (head_or 0 (Cons 1 (Cons 2 Empty)), head_or 0 Empty)
2793 "#,
2794 );
2795 let mut ts = TypeSystem::new_with_prelude().unwrap();
2796 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2797 let expected = Type::tuple(vec![
2798 Type::builtin(BuiltinTypeId::I32),
2799 Type::builtin(BuiltinTypeId::I32),
2800 ]);
2801 assert_eq!(ty, expected);
2802 }
2803
2804 #[test]
2805 fn infer_record_pattern_in_lambda() {
2806 let program = parse_program(
2807 r#"
2808 type Pair = Pair { left: i32, right: i32 }
2809 let
2810 sum = \p ->
2811 match p
2812 when Pair { left, right } -> left + right
2813 in
2814 sum (Pair { left = 1, right = 2 })
2815 "#,
2816 );
2817 let mut ts = TypeSystem::new_with_prelude().unwrap();
2818 for decl in &program.decls {
2819 if let rexlang_ast::expr::Decl::Type(decl) = decl {
2820 ts.register_type_decl(decl).unwrap();
2821 }
2822 }
2823 let (_preds, ty) = infer(&mut ts, program.expr.as_ref()).unwrap();
2824 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2825 }
2826
2827 #[test]
2828 fn infer_fn_decl_simple() {
2829 let program = parse_program(
2830 r#"
2831 fn add (x: i32, y: i32) -> i32 = x + y
2832 add 1 2
2833 "#,
2834 );
2835 let mut ts = TypeSystem::new_with_prelude().unwrap();
2836 let expr = program.expr_with_fns();
2837 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2838 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2839 }
2840
2841 #[test]
2842 fn infer_fn_decl_signature_form() {
2843 let program = parse_program(
2844 r#"
2845 fn add : i32 -> i32 -> i32 = \x y -> x + y
2846 add 1 2
2847 "#,
2848 );
2849 let mut ts = TypeSystem::new_with_prelude().unwrap();
2850 let expr = program.expr_with_fns();
2851 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2852 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2853 }
2854
2855 #[test]
2856 fn infer_fn_decl_polymorphic_where_constraints() {
2857 let program = parse_program(
2858 r#"
2859 fn my_add (x: a, y: a) -> a where AdditiveMonoid a = x + y
2860 (my_add 1 2, my_add 1.0 2.0)
2861 "#,
2862 );
2863 let mut ts = TypeSystem::new_with_prelude().unwrap();
2864 let expr = program.expr_with_fns();
2865 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2866 assert_eq!(
2867 ty,
2868 Type::tuple(vec![
2869 Type::builtin(BuiltinTypeId::I32),
2870 Type::builtin(BuiltinTypeId::F32)
2871 ])
2872 );
2873 }
2874
2875 #[test]
2876 fn infer_additive_monoid_constraint() {
2877 let expr = parse_expr("\\x y -> x + y");
2878 let mut ts = TypeSystem::new_with_prelude().unwrap();
2879 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2880 assert_eq!(preds.len(), 1);
2881 assert_eq!(preds[0].class.as_ref(), "AdditiveMonoid");
2882
2883 if let TypeKind::Fun(a, rest) = ty.as_ref()
2884 && let TypeKind::Fun(b, c) = rest.as_ref()
2885 {
2886 assert_eq!(a.as_ref(), b.as_ref());
2887 assert_eq!(b.as_ref(), c.as_ref());
2888 assert_eq!(preds[0].typ, a.clone());
2889 return;
2890 }
2891 panic!("expected a -> a -> a");
2892 }
2893
2894 #[test]
2895 fn infer_multiplicative_monoid_constraint() {
2896 let expr = parse_expr("\\x y -> x * y");
2897 let mut ts = TypeSystem::new_with_prelude().unwrap();
2898 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2899 assert_eq!(preds.len(), 1);
2900 assert_eq!(preds[0].class.as_ref(), "MultiplicativeMonoid");
2901
2902 if let TypeKind::Fun(a, rest) = ty.as_ref()
2903 && let TypeKind::Fun(b, c) = rest.as_ref()
2904 {
2905 assert_eq!(a.as_ref(), b.as_ref());
2906 assert_eq!(b.as_ref(), c.as_ref());
2907 assert_eq!(preds[0].typ, a.clone());
2908 return;
2909 }
2910 panic!("expected a -> a -> a");
2911 }
2912
2913 #[test]
2914 fn infer_additive_group_constraint() {
2915 let expr = parse_expr("\\x y -> x - y");
2916 let mut ts = TypeSystem::new_with_prelude().unwrap();
2917 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2918 assert_eq!(preds.len(), 1);
2919 assert_eq!(preds[0].class.as_ref(), "AdditiveGroup");
2920
2921 if let TypeKind::Fun(a, rest) = ty.as_ref()
2922 && let TypeKind::Fun(b, c) = rest.as_ref()
2923 {
2924 assert_eq!(a.as_ref(), b.as_ref());
2925 assert_eq!(b.as_ref(), c.as_ref());
2926 assert_eq!(preds[0].typ, a.clone());
2927 return;
2928 }
2929 panic!("expected a -> a -> a");
2930 }
2931
2932 #[test]
2933 fn infer_integral_constraint() {
2934 let expr = parse_expr("\\x y -> x % y");
2935 let mut ts = TypeSystem::new_with_prelude().unwrap();
2936 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2937 assert_eq!(preds.len(), 1);
2938 assert_eq!(preds[0].class.as_ref(), "Integral");
2939
2940 if let TypeKind::Fun(a, rest) = ty.as_ref()
2941 && let TypeKind::Fun(b, c) = rest.as_ref()
2942 {
2943 assert_eq!(a.as_ref(), b.as_ref());
2944 assert_eq!(b.as_ref(), c.as_ref());
2945 assert_eq!(preds[0].typ, a.clone());
2946 return;
2947 }
2948 panic!("expected a -> a -> a");
2949 }
2950
2951 #[test]
2952 fn infer_literal_addition_defaults() {
2953 let expr = parse_expr("1 + 2");
2954 let mut ts = TypeSystem::new_with_prelude().unwrap();
2955 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2956 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2957 assert_eq!(preds.len(), 2);
2958 assert!(preds.iter().any(|p| p.class.as_ref() == "AdditiveMonoid"));
2959 assert!(preds.iter().any(|p| p.class.as_ref() == "Integral"));
2960 assert!(
2961 preds
2962 .iter()
2963 .all(|p| p.typ == Type::builtin(BuiltinTypeId::I32))
2964 );
2965 }
2966
2967 #[test]
2968 fn infer_mod_defaults() {
2969 let expr = parse_expr("1 % 2");
2970 let mut ts = TypeSystem::new_with_prelude().unwrap();
2971 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2972 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2973 assert_eq!(preds.len(), 1);
2974 assert_eq!(preds[0].class.as_ref(), "Integral");
2975 assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::I32));
2976 }
2977
2978 #[test]
2979 fn infer_get_list_type() {
2980 let expr = parse_expr("get 1 [1, 2, 3]");
2981 let mut ts = TypeSystem::new_with_prelude().unwrap();
2982 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2983 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
2984 assert!(preds.iter().any(|p| p.class.as_ref() == "Indexable"));
2985 assert!(preds.iter().all(|p| {
2986 p.class.as_ref() == "Indexable"
2987 || (p.class.as_ref() == "Integral" && p.typ == Type::builtin(BuiltinTypeId::I32))
2988 }));
2989 for pred in preds.iter().filter(|p| p.class.as_ref() == "Indexable") {
2990 assert!(entails(&ts.classes, &[], pred).unwrap());
2991 }
2992 }
2993
2994 #[test]
2995 fn infer_get_tuple_type() {
2996 let expr = parse_expr("(1, 'Hello', true).0");
2997 let mut ts = TypeSystem::new_with_prelude().unwrap();
2998 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
2999 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3000 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3001
3002 let expr = parse_expr("(1, 'Hello', true).1");
3003 let mut ts = TypeSystem::new_with_prelude().unwrap();
3004 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3005 assert_eq!(ty, Type::builtin(BuiltinTypeId::String));
3006 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3007
3008 let expr = parse_expr("(1, 'Hello', true).2");
3009 let mut ts = TypeSystem::new_with_prelude().unwrap();
3010 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3011 assert_eq!(ty, Type::builtin(BuiltinTypeId::Bool));
3012 assert!(preds.is_empty() || preds.iter().all(|p| p.class.as_ref() == "Integral"));
3013 }
3014
3015 #[test]
3016 fn infer_division_defaults() {
3017 let expr = parse_expr("1.0 / 2.0");
3018 let mut ts = TypeSystem::new_with_prelude().unwrap();
3019 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3020 assert_eq!(ty, Type::builtin(BuiltinTypeId::F32));
3021 assert_eq!(preds.len(), 1);
3022 assert_eq!(preds[0].class.as_ref(), "Field");
3023 assert_eq!(preds[0].typ, Type::builtin(BuiltinTypeId::F32));
3024 assert!(entails(&ts.classes, &[], &preds[0]).unwrap());
3025 }
3026
3027 #[test]
3028 fn infer_unbound_variable_error() {
3029 let expr = parse_expr("missing");
3030 let mut ts = TypeSystem::new_with_prelude().unwrap();
3031 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3032 assert!(matches!(
3033 err,
3034 TypeError::UnknownVar(name) if name.as_ref() == "missing"
3035 ));
3036 }
3037
3038 #[test]
3039 fn infer_if_branch_type_mismatch_error() {
3040 let expr = parse_expr(r#"if true then 1 else "no""#);
3041 let mut ts = TypeSystem::new_with_prelude().unwrap();
3042 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3043 match err {
3044 TypeError::Unification(a, b) => {
3045 let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
3046 assert!(ok, "expected i32 vs string, got {a} vs {b}");
3047 }
3048 other => panic!("expected unification error, got {other:?}"),
3049 }
3050 }
3051
3052 #[test]
3053 fn infer_unknown_pattern_constructor_error() {
3054 let expr = parse_expr("match 1 when Nope -> 1");
3055 let mut ts = TypeSystem::new_with_prelude().unwrap();
3056 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3057 assert!(matches!(
3058 err,
3059 TypeError::UnknownVar(name) if name.as_ref() == "Nope"
3060 ));
3061 }
3062
3063 #[test]
3064 fn infer_ambiguous_overload_error() {
3065 let mut ts = TypeSystem::new();
3066 let a = TypeVar::new(0, Some(sym("a")));
3067 let b = TypeVar::new(1, Some(sym("b")));
3068 let scheme_a = Scheme::new(vec![a.clone()], vec![], Type::var(a));
3069 let scheme_b = Scheme::new(vec![b.clone()], vec![], Type::var(b));
3070 ts.add_overload(sym("dup"), scheme_a);
3071 ts.add_overload(sym("dup"), scheme_b);
3072 let expr = parse_expr("dup");
3073 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3074 assert!(matches!(
3075 err,
3076 TypeError::AmbiguousOverload(name) if name.as_ref() == "dup"
3077 ));
3078 }
3079
3080 #[test]
3081 fn infer_if_cond_not_bool_error() {
3082 let expr = parse_expr("if 1 then 2 else 3");
3083 let mut ts = TypeSystem::new_with_prelude().unwrap();
3084 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3085 match err {
3086 TypeError::Unification(a, b) => {
3087 let ok = (a == "bool" && b == "i32") || (a == "i32" && b == "bool");
3088 assert!(ok, "expected bool vs i32, got {a} vs {b}");
3089 }
3090 other => panic!("expected unification error, got {other:?}"),
3091 }
3092 }
3093
3094 #[test]
3095 fn infer_apply_non_function_error() {
3096 let expr = parse_expr("1 2");
3097 let mut ts = TypeSystem::new_with_prelude().unwrap();
3098 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3099 assert!(matches!(err, TypeError::Unification(_, _)));
3100 }
3101
3102 #[test]
3103 fn infer_list_element_mismatch_error() {
3104 let expr = parse_expr("[1, true]");
3105 let mut ts = TypeSystem::new_with_prelude().unwrap();
3106 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3107 match err {
3108 TypeError::Unification(a, b) => {
3109 let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
3110 assert!(ok, "expected i32 vs bool, got {a} vs {b}");
3111 }
3112 other => panic!("expected unification error, got {other:?}"),
3113 }
3114 }
3115
3116 #[test]
3117 fn infer_dict_value_mismatch_error() {
3118 let expr = parse_expr("{a = 1, b = true}");
3119 let mut ts = TypeSystem::new_with_prelude().unwrap();
3120 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3121 match err {
3122 TypeError::Unification(a, b) => {
3123 let ok = (a == "i32" && b == "bool") || (a == "bool" && b == "i32");
3124 assert!(ok, "expected i32 vs bool, got {a} vs {b}");
3125 }
3126 other => panic!("expected unification error, got {other:?}"),
3127 }
3128 }
3129
3130 #[test]
3131 fn infer_match_list_on_non_list_error() {
3132 let expr = parse_expr("match 1 when [x] -> x");
3133 let mut ts = TypeSystem::new_with_prelude().unwrap();
3134 assert!(infer(&mut ts, expr.as_ref()).is_err());
3135 }
3136
3137 #[test]
3138 fn infer_pattern_constructor_arity_error() {
3139 let expr = parse_expr("match (Ok 1) when Ok x y -> x");
3140 let mut ts = TypeSystem::new_with_prelude().unwrap();
3141 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3142 assert!(matches!(
3143 err,
3144 TypeError::UnsupportedExpr("pattern constructor")
3145 ));
3146 }
3147
3148 #[test]
3149 fn infer_match_arm_type_mismatch_error() {
3150 let expr = parse_expr(r#"match 1 when _ -> 1 when _ -> "no""#);
3151 let mut ts = TypeSystem::new_with_prelude().unwrap();
3152 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3153 match err {
3154 TypeError::Unification(a, b) => {
3155 let ok = (a == "i32" && b == "string") || (a == "string" && b == "i32");
3156 assert!(ok, "expected i32 vs string, got {a} vs {b}");
3157 }
3158 other => panic!("expected unification error, got {other:?}"),
3159 }
3160 }
3161
3162 #[test]
3163 fn infer_match_option_on_non_option_error() {
3164 let expr = parse_expr("match 1 when Some x -> x");
3165 let mut ts = TypeSystem::new_with_prelude().unwrap();
3166 assert!(infer(&mut ts, expr.as_ref()).is_err());
3167 }
3168
3169 #[test]
3170 fn infer_dict_pattern_on_non_dict_error() {
3171 let expr = parse_expr("match 1 when {a} -> a");
3172 let mut ts = TypeSystem::new_with_prelude().unwrap();
3173 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3174 assert!(matches!(err, TypeError::Unification(_, _)));
3175 }
3176
3177 #[test]
3178 fn infer_cons_pattern_on_non_list_error() {
3179 let expr = parse_expr("match 1 when x::xs -> x");
3180 let mut ts = TypeSystem::new_with_prelude().unwrap();
3181 assert!(infer(&mut ts, expr.as_ref()).is_err());
3182 }
3183
3184 #[test]
3185 fn infer_apply_wrong_arg_type_error() {
3186 let expr = parse_expr("(\\x -> x + 1) true");
3187 let mut ts = TypeSystem::new_with_prelude().unwrap();
3188 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3189 assert!(matches!(err, TypeError::Unification(_, _)));
3190 }
3191
3192 #[test]
3193 fn infer_self_application_occurs_error() {
3194 let expr = parse_expr("\\x -> x x");
3195 let mut ts = TypeSystem::new_with_prelude().unwrap();
3196 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3197 assert!(matches!(err, TypeError::Occurs(_, _)));
3198 }
3199
3200 #[test]
3201 fn infer_apply_constructor_too_many_args_error() {
3202 let expr = parse_expr("Some 1 2");
3203 let mut ts = TypeSystem::new_with_prelude().unwrap();
3204 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3205 assert!(matches!(err, TypeError::Unification(_, _)));
3206 }
3207
3208 #[test]
3209 fn infer_operator_type_mismatch_error() {
3210 let expr = parse_expr("1 + true");
3211 let mut ts = TypeSystem::new_with_prelude().unwrap();
3212 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3213 assert!(matches!(err, TypeError::Unification(_, _)));
3214 }
3215
3216 #[test]
3217 fn infer_non_exhaustive_match_is_error() {
3218 let expr = parse_expr("match (Ok 1) when Ok x -> x");
3219 let mut ts = TypeSystem::new_with_prelude().unwrap();
3220 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3221 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3222 }
3223
3224 #[test]
3225 fn infer_non_exhaustive_match_on_bound_var_error() {
3226 let expr = parse_expr("let x = Ok 1 in match x when Ok y -> y");
3227 let mut ts = TypeSystem::new_with_prelude().unwrap();
3228 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3229 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3230 }
3231
3232 #[test]
3233 fn infer_non_exhaustive_match_in_lambda_error() {
3234 let expr = parse_expr("\\x -> match x when Ok y -> y");
3235 let mut ts = TypeSystem::new_with_prelude().unwrap();
3236 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3237 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3238 }
3239
3240 #[test]
3241 fn infer_non_exhaustive_option_match_error() {
3242 let expr = parse_expr("match (Some 1) when Some x -> x");
3243 let mut ts = TypeSystem::new_with_prelude().unwrap();
3244 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3245 match err {
3246 TypeError::NonExhaustiveMatch { missing, .. } => {
3247 assert_eq!(missing, vec![sym("None")]);
3248 }
3249 other => panic!("expected non-exhaustive match, got {other:?}"),
3250 }
3251 }
3252
3253 #[test]
3254 fn infer_non_exhaustive_result_match_error() {
3255 let expr = parse_expr("match (Err 1) when Ok x -> x");
3256 let mut ts = TypeSystem::new_with_prelude().unwrap();
3257 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3258 match err {
3259 TypeError::NonExhaustiveMatch { missing, .. } => {
3260 assert_eq!(missing, vec![sym("Err")]);
3261 }
3262 other => panic!("expected non-exhaustive match, got {other:?}"),
3263 }
3264 }
3265
3266 #[test]
3267 fn infer_non_exhaustive_list_missing_empty_error() {
3268 let expr = parse_expr("match [1, 2] when x::xs -> x");
3269 let mut ts = TypeSystem::new_with_prelude().unwrap();
3270 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3271 match err {
3272 TypeError::NonExhaustiveMatch { missing, .. } => {
3273 assert_eq!(missing, vec![sym("Empty")]);
3274 }
3275 other => panic!("expected non-exhaustive match, got {other:?}"),
3276 }
3277 }
3278
3279 #[test]
3280 fn infer_non_exhaustive_list_match_on_bound_var_error() {
3281 let expr = parse_expr("let xs = [1, 2] in match xs when x::xs -> x");
3282 let mut ts = TypeSystem::new_with_prelude().unwrap();
3283 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3284 assert!(matches!(err, TypeError::NonExhaustiveMatch { .. }));
3285 }
3286
3287 #[test]
3288 fn infer_non_exhaustive_list_missing_cons_error() {
3289 let expr = parse_expr("match [1] when [] -> 0");
3290 let mut ts = TypeSystem::new_with_prelude().unwrap();
3291 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3292 match err {
3293 TypeError::NonExhaustiveMatch { missing, .. } => {
3294 assert_eq!(missing, vec![sym("Cons")]);
3295 }
3296 other => panic!("expected non-exhaustive match, got {other:?}"),
3297 }
3298 }
3299
3300 #[test]
3301 fn infer_match_list_patterns_on_result_error() {
3302 let expr = parse_expr("match (Ok 1) when [] -> 0 when x::xs -> 1");
3303 let mut ts = TypeSystem::new_with_prelude().unwrap();
3304 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3305 assert!(matches!(err, TypeError::Unification(_, _)));
3306 }
3307
3308 #[test]
3309 fn infer_missing_instances_produce_unsatisfied_predicates() {
3310 for (name, code) in [
3311 ("division", "1 / 2"),
3312 ("eq_dict", "{a = 1} == {a = 2}"),
3313 ("min_bool", "min [true]"),
3314 ("map_dict", r#"map (\x -> x) {a = 1}"#),
3315 ] {
3316 let (class, pred_type, expected_ty) = match name {
3317 "division" => (
3318 "Field",
3319 Type::builtin(BuiltinTypeId::I32),
3320 Some(Type::builtin(BuiltinTypeId::I32)),
3321 ),
3322 "eq_dict" => ("Eq", dict_of(Type::builtin(BuiltinTypeId::I32)), None),
3323 "min_bool" => ("Ord", Type::builtin(BuiltinTypeId::Bool), None),
3324 "map_dict" => ("Functor", Type::builtin(BuiltinTypeId::Dict), None),
3325 _ => unreachable!("unknown test case {name}"),
3326 };
3327
3328 let expr = parse_expr(code);
3329 let mut ts = TypeSystem::new_with_prelude().unwrap();
3330 let (preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3331 if let Some(expected) = expected_ty {
3332 assert_eq!(ty, expected, "{name}");
3333 }
3334
3335 let pred = preds
3336 .iter()
3337 .find(|p| p.class.as_ref() == class && p.typ == pred_type)
3338 .unwrap();
3339 assert!(!entails(&ts.classes, &[], pred).unwrap(), "{name}");
3340 }
3341 }
3342
3343 #[test]
3344 fn record_update_single_variant_adt_infers() {
3345 let program = parse_program(
3346 r#"
3347 type Foo = Bar { x: i32, y: i32 }
3348 let
3349 foo: Foo = Bar { x = 1, y = 2 },
3350 bar = { foo with { x = 3 } }
3351 in
3352 bar
3353 "#,
3354 );
3355 let mut ts = TypeSystem::new_with_prelude().unwrap();
3356 ts.register_decls(&program.decls).unwrap();
3357 let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3358 assert_eq!(typ.to_string(), "Foo");
3359 }
3360
3361 #[test]
3362 fn record_update_unknown_field_errors() {
3363 let program = parse_program(
3364 r#"
3365 type Foo = Bar { x: i32 }
3366 let
3367 foo: Foo = Bar { x = 1 }
3368 in
3369 { foo with { y = 2 } }
3370 "#,
3371 );
3372 let mut ts = TypeSystem::new_with_prelude().unwrap();
3373 ts.register_decls(&program.decls).unwrap();
3374 let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
3375 let err = strip_span(err);
3376 assert!(matches!(err, TypeError::UnknownField { .. }));
3377 }
3378
3379 #[test]
3380 fn record_update_requires_refined_variant_for_sum_types() {
3381 let program = parse_program(
3382 r#"
3383 type Foo = Bar { x: i32 } | Baz { x: i32 }
3384 let
3385 f = \ (foo : Foo) -> { foo with { x = 2 } }
3386 in
3387 f (Bar { x = 1 })
3388 "#,
3389 );
3390 let mut ts = TypeSystem::new_with_prelude().unwrap();
3391 ts.register_decls(&program.decls).unwrap();
3392 let err = infer(&mut ts, program.expr.as_ref()).unwrap_err();
3393 let err = strip_span(err);
3394 assert!(matches!(err, TypeError::FieldNotKnown { .. }));
3395 }
3396
3397 #[test]
3398 fn record_update_allowed_after_match_refines_variant() {
3399 let program = parse_program(
3400 r#"
3401 type Foo = Bar { x: i32 } | Baz { x: i32 }
3402 let
3403 f = \ (foo : Foo) ->
3404 match foo
3405 when Bar {x} -> { foo with { x = x + 1 } }
3406 when Baz {x} -> { foo with { x = x + 2 } }
3407 in
3408 f (Bar { x = 1 })
3409 "#,
3410 );
3411 let mut ts = TypeSystem::new_with_prelude().unwrap();
3412 ts.register_decls(&program.decls).unwrap();
3413 let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3414 assert_eq!(typ.to_string(), "Foo");
3415 }
3416
3417 #[test]
3418 fn record_update_plain_record_type() {
3419 let program = parse_program(
3420 r#"
3421 let
3422 f = \ (r : { x: i32, y: i32 }) -> { r with { y = 9 } }
3423 in
3424 f { x = 1, y = 2 }
3425 "#,
3426 );
3427 let mut ts = TypeSystem::new_with_prelude().unwrap();
3428 ts.register_decls(&program.decls).unwrap();
3429 let (_preds, typ) = infer(&mut ts, program.expr.as_ref()).unwrap();
3430 assert_eq!(typ.to_string(), "{x: i32, y: i32}");
3431 }
3432
3433 #[test]
3434 fn infer_typed_hole_expr_is_hole_kind() {
3435 let expr = parse_expr("?");
3436 let mut ts = TypeSystem::new_with_prelude().unwrap();
3437 let (typed, _preds, _ty) = infer_typed(&mut ts, expr.as_ref()).unwrap();
3438 assert!(
3439 matches!(typed.kind, TypedExprKind::Hole),
3440 "typed={typed:#?}"
3441 );
3442 }
3443
3444 #[test]
3445 fn infer_hole_with_annotation_unifies_to_annotation() {
3446 let expr = parse_expr("let x : i32 = ? in x");
3447 let mut ts = TypeSystem::new_with_prelude().unwrap();
3448 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3449 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3450 }
3451
3452 #[test]
3453 fn infer_hole_in_if_condition_is_bool_constrained() {
3454 let expr = parse_expr("if ? then 1 else 2");
3455 let mut ts = TypeSystem::new_with_prelude().unwrap();
3456 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3457 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3458 }
3459
3460 #[test]
3461 fn infer_hole_in_arithmetic_is_numeric_constrained() {
3462 let expr = parse_expr("? + 1");
3463 let mut ts = TypeSystem::new_with_prelude().unwrap();
3464 let (_preds, ty) = infer(&mut ts, expr.as_ref()).unwrap();
3465 assert_eq!(ty, Type::builtin(BuiltinTypeId::I32));
3466 }
3467
3468 #[test]
3469 fn infer_hole_arithmetic_conflicting_annotation_failure() {
3470 let expr = parse_expr("let x : string = (? + 1) in x");
3471 let mut ts = TypeSystem::new_with_prelude().unwrap();
3472 let err = strip_span(infer(&mut ts, expr.as_ref()).unwrap_err());
3473 assert!(matches!(err, TypeError::Unification(_, _)), "err={err:#?}");
3474 }
3475}