1use crate::{
2 error::TypeError,
3 types::{
4 AdtDecl, AdtVariant, BuiltinTypeId, Predicate, Scheme, Type, TypeEnv, TypeKind, TypeVar,
5 TypeVarId, TypedExpr, TypedExprKind, Types,
6 },
7 typesystem::{
8 TypeSystem, TypeVarSupply, instantiate, is_integral_literal_expr,
9 predicates_from_constraints, reject_ambiguous_scheme, type_from_annotation_expr,
10 type_from_annotation_expr_vars,
11 },
12 unification::{Subst, Unifier, compose_subst, subst_is_empty, unify},
13};
14use rex_ast::{Expr, Pattern, Symbol, TypeConstraint, TypeExpr};
15use std::{
16 collections::{BTreeMap, BTreeSet},
17 sync::Arc,
18};
19
20fn dedup_preds(preds: Vec<Predicate>) -> Vec<Predicate> {
21 let mut seen = BTreeSet::new();
22 let mut out = Vec::with_capacity(preds.len());
23 for pred in preds {
24 if seen.insert(pred.clone()) {
25 out.push(pred);
26 }
27 }
28 out
29}
30
31fn is_integral_primitive(typ: &Type) -> bool {
32 matches!(
33 typ.as_ref(),
34 TypeKind::Con(tc)
35 if matches!(
36 tc.builtin_id(),
37 Some(
38 BuiltinTypeId::U8
39 | BuiltinTypeId::U16
40 | BuiltinTypeId::U32
41 | BuiltinTypeId::U64
42 | BuiltinTypeId::I8
43 | BuiltinTypeId::I16
44 | BuiltinTypeId::I32
45 | BuiltinTypeId::I64
46 )
47 )
48 )
49}
50
51fn finalize_infer_for_public_api(
52 mut preds: Vec<Predicate>,
53 mut typ: Type,
54) -> Result<(Vec<Predicate>, Type), TypeError> {
55 let mut subst = Subst::new_sync();
56 for pred in &preds {
57 if pred.class.as_ref() == "Integral"
58 && let TypeKind::Var(tv) = pred.typ.as_ref()
59 {
60 subst = subst.insert(tv.id, Type::builtin(BuiltinTypeId::I32));
61 }
62 }
63
64 if !subst_is_empty(&subst) {
65 preds = dedup_preds(preds.apply(&subst));
66 typ = typ.apply(&subst);
67 }
68
69 for pred in &preds {
70 if pred.class.as_ref() != "Integral" {
71 continue;
72 }
73 if matches!(pred.typ.as_ref(), TypeKind::Var(_)) || is_integral_primitive(&pred.typ) {
74 continue;
75 }
76 return Err(TypeError::Unification("i32".into(), pred.typ.to_string()));
77 }
78
79 Ok((preds, typ))
80}
81
82#[derive(Clone, Debug)]
83struct KnownVariant {
84 adt: Symbol,
85 variant: Symbol,
86}
87
88type KnownVariants = BTreeMap<Symbol, KnownVariant>;
89
90fn apply_scheme_with_unifier(scheme: &Scheme, unifier: &mut Unifier) -> Scheme {
91 let preds = scheme
92 .preds
93 .iter()
94 .map(|pred| Predicate::new(pred.class.clone(), unifier.apply_type(&pred.typ)))
95 .collect();
96 let typ = unifier.apply_type(&scheme.typ);
97 Scheme::new(scheme.vars.clone(), preds, typ)
98}
99
100fn scheme_ftv_with_unifier(scheme: &Scheme, unifier: &mut Unifier) -> BTreeSet<TypeVarId> {
101 let mut ftv = unifier.apply_type(&scheme.typ).ftv();
102 for pred in &scheme.preds {
103 ftv.extend(unifier.apply_type(&pred.typ).ftv());
104 }
105 for var in &scheme.vars {
106 ftv.remove(&var.id);
107 }
108 ftv
109}
110
111fn env_ftv_with_unifier(env: &TypeEnv, unifier: &mut Unifier) -> BTreeSet<TypeVarId> {
112 let mut out = BTreeSet::new();
113 for (_name, schemes) in env.values.iter() {
114 for scheme in schemes {
115 out.extend(scheme_ftv_with_unifier(scheme, unifier));
116 }
117 }
118 out
119}
120
121fn generalize_with_unifier(
122 env: &TypeEnv,
123 preds: Vec<Predicate>,
124 typ: Type,
125 unifier: &mut Unifier,
126) -> Scheme {
127 let preds: Vec<Predicate> = preds
128 .into_iter()
129 .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
130 .collect();
131 let typ = unifier.apply_type(&typ);
132 let mut vars: Vec<TypeVar> = typ
133 .ftv()
134 .union(&preds.ftv())
135 .copied()
136 .collect::<BTreeSet<_>>()
137 .difference(&env_ftv_with_unifier(env, unifier))
138 .cloned()
139 .map(|id| TypeVar::new(id, None))
140 .collect();
141 vars.sort_by_key(|v| v.id);
142 Scheme::new(vars, preds, typ)
143}
144
145fn monomorphic_scheme_with_unifier(
146 preds: Vec<Predicate>,
147 typ: Type,
148 unifier: &mut Unifier,
149) -> Scheme {
150 let preds = dedup_preds(
151 preds
152 .into_iter()
153 .map(|pred| Predicate::new(pred.class, unifier.apply_type(&pred.typ)))
154 .collect(),
155 );
156 let typ = unifier.apply_type(&typ);
157 Scheme::new(vec![], preds, typ)
158}
159
160pub fn infer_typed(
161 type_system: &mut TypeSystem,
162 expr: &Expr,
163) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
164 infer_typed_inner(type_system, expr)
165}
166
167fn infer_typed_inner(
168 type_system: &mut TypeSystem,
169 expr: &Expr,
170) -> Result<(TypedExpr, Vec<Predicate>, Type), TypeError> {
171 let known = KnownVariants::new();
172 let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
173 let (preds, t, typed) = infer_expr(
174 &mut unifier,
175 &mut type_system.supply,
176 &type_system.env,
177 &type_system.adts,
178 &known,
179 expr,
180 )
181 .map_err(|err| err.with_span(expr.span()))?;
182 let subst = unifier.into_subst();
183 let mut typed = typed.apply(&subst);
184 let mut preds = dedup_preds(preds.apply(&subst));
185 let mut t = t.apply(&subst);
186 let improve = improve_indexable(&preds)?;
187 if !subst_is_empty(&improve) {
188 typed = typed.apply(&improve);
189 preds = dedup_preds(preds.apply(&improve));
190 t = t.apply(&improve);
191 }
192 type_system.check_predicate_kinds(&preds)?;
193 Ok((typed, preds, t))
194}
195
196pub fn infer(
197 type_system: &mut TypeSystem,
198 expr: &Expr,
199) -> Result<(Vec<Predicate>, Type), TypeError> {
200 infer_inner(type_system, expr)
201}
202
203fn infer_inner(
204 type_system: &mut TypeSystem,
205 expr: &Expr,
206) -> Result<(Vec<Predicate>, Type), TypeError> {
207 let known = KnownVariants::new();
208 let mut unifier = Unifier::new(type_system.limits.max_infer_depth);
209 let (preds, t) = infer_expr_type(
210 &mut unifier,
211 &mut type_system.supply,
212 &type_system.env,
213 &type_system.adts,
214 &known,
215 expr,
216 )
217 .map_err(|err| err.with_span(expr.span()))?;
218 let subst = unifier.into_subst();
219 let mut preds = dedup_preds(preds.apply(&subst));
220 let mut t = t.apply(&subst);
221 let improve = improve_indexable(&preds)?;
222 if !subst_is_empty(&improve) {
223 preds = dedup_preds(preds.apply(&improve));
224 t = t.apply(&improve);
225 }
226 type_system.check_predicate_kinds(&preds)?;
227 finalize_infer_for_public_api(preds, t)
228}
229
230fn improve_indexable(preds: &[Predicate]) -> Result<Subst, TypeError> {
231 let mut subst = Subst::new_sync();
232 loop {
233 let mut changed = false;
234 for pred in preds {
235 let pred = pred.apply(&subst);
236 if pred.class.as_ref() != "Indexable" {
237 continue;
238 }
239 let TypeKind::Tuple(parts) = pred.typ.as_ref() else {
240 continue;
241 };
242 if parts.len() != 2 {
243 continue;
244 }
245 let container = parts[0].clone();
246 let elem = parts[1].clone();
247 let s = indexable_elem_subst(&container, &elem)?;
248 if !subst_is_empty(&s) {
249 subst = compose_subst(s, subst);
250 changed = true;
251 }
252 }
253 if !changed {
254 break;
255 }
256 }
257 Ok(subst)
258}
259
260fn indexable_elem_subst(container: &Type, elem: &Type) -> Result<Subst, TypeError> {
261 match container.as_ref() {
262 TypeKind::App(head, arg) => match head.as_ref() {
263 TypeKind::Con(tc)
264 if matches!(
265 tc.builtin_id(),
266 Some(BuiltinTypeId::List | BuiltinTypeId::Array)
267 ) =>
268 {
269 unify(elem, arg)
270 }
271 _ => Ok(Subst::new_sync()),
272 },
273 TypeKind::Tuple(elems) => {
274 if elems.is_empty() {
275 return Ok(Subst::new_sync());
276 }
277 let mut subst = Subst::new_sync();
278 let mut cur = elems[0].clone();
279 for ty in elems.iter().skip(1) {
280 let s_next = unify(&cur.apply(&subst), &ty.apply(&subst))?;
281 subst = compose_subst(s_next, subst);
282 cur = cur.apply(&subst);
283 }
284 let elem = elem.apply(&subst);
285 let s_elem = unify(&elem, &cur.apply(&subst))?;
286 Ok(compose_subst(s_elem, subst))
287 }
288 _ => Ok(Subst::new_sync()),
289 }
290}
291
292type LambdaChain<'a> = (
293 Vec<(Symbol, Option<TypeExpr>)>,
294 Vec<TypeConstraint>,
295 &'a Expr,
296);
297
298fn collect_lambda_chain<'a>(expr: &'a Expr) -> LambdaChain<'a> {
299 let mut params = Vec::new();
300 let mut constraints = Vec::new();
301 let mut cur = expr;
302 let mut seen_constraints = false;
303 while let Expr::Lam(_, _scope, param, ann, lam_constraints, body) = cur {
304 if !lam_constraints.is_empty() {
305 if seen_constraints {
306 break;
307 }
308 constraints = lam_constraints.clone();
309 seen_constraints = true;
310 }
311 params.push((param.name.clone(), ann.clone()));
312 cur = body.as_ref();
313 }
314 (params, constraints, cur)
315}
316
317fn collect_app_chain(expr: &Expr) -> (&Expr, Vec<&Expr>) {
318 let mut args = Vec::new();
319 let mut cur = expr;
320 while let Expr::App(_, f, x) = cur {
321 args.push(x.as_ref());
322 cur = f.as_ref();
323 }
324 args.reverse();
325 (cur, args)
326}
327
328fn narrow_overload_candidates(candidates: &[Type], arg_ty: &Type) -> Vec<Type> {
329 let mut out = Vec::new();
330 for candidate in candidates {
331 let Some((params, ret)) = decompose_fun(candidate, 1) else {
332 continue;
333 };
334 let param = ¶ms[0];
335 if let Ok(s) = unify(param, arg_ty) {
336 out.push(ret.apply(&s));
337 }
338 }
339 out
340}
341
342fn unary_app_arg(typ: &Type, ctor_name: &str) -> Option<Type> {
343 let TypeKind::App(head, arg) = typ.as_ref() else {
344 return None;
345 };
346 let TypeKind::Con(tc) = head.as_ref() else {
347 return None;
348 };
349 (tc.name_str() == ctor_name && tc.arity() == 1).then(|| arg.clone())
350}
351
352fn infer_app_arg_type(
353 unifier: &mut Unifier,
354 supply: &mut TypeVarSupply,
355 env: &TypeEnv,
356 adts: &BTreeMap<Symbol, AdtDecl>,
357 known: &KnownVariants,
358 arg_hint: Option<Type>,
359 arg: &Expr,
360) -> Result<(Vec<Predicate>, Type), TypeError> {
361 match (arg_hint, arg) {
362 (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
363 infer_record_update_type_with_hint(
364 unifier,
365 supply,
366 env,
367 adts,
368 known,
369 base.as_ref(),
370 updates,
371 &arg_hint,
372 )
373 }
374 (Some(arg_hint), Expr::Dict(_, kvs))
375 if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
376 {
377 let TypeKind::Record(fields) = arg_hint.as_ref() else {
378 unreachable!("guarded by matches!")
379 };
380 let expected: BTreeMap<_, _> =
381 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
382 let mut seen = BTreeSet::new();
383 let mut preds = Vec::new();
384 for (k, v) in kvs {
385 let expected_ty = expected
386 .get(k)
387 .ok_or_else(|| TypeError::UnknownField {
388 field: k.clone(),
389 typ: Type::record(fields.clone()).to_string(),
390 })?
391 .clone();
392 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
393 unifier.unify(&t1, &expected_ty)?;
394 preds.extend(p1);
395 seen.insert(k.clone());
396 }
397 for key in expected.keys() {
398 if !seen.contains(key.as_ref()) {
399 return Err(TypeError::UnknownField {
400 field: key.clone(),
401 typ: Type::record(fields.clone()).to_string(),
402 });
403 }
404 }
405 let record_ty = Type::record(
406 fields
407 .iter()
408 .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
409 .collect(),
410 );
411 Ok((preds, record_ty))
412 }
413 _ => infer_expr_type(unifier, supply, env, adts, known, arg),
414 }
415}
416
417fn infer_app_arg_typed(
418 unifier: &mut Unifier,
419 supply: &mut TypeVarSupply,
420 env: &TypeEnv,
421 adts: &BTreeMap<Symbol, AdtDecl>,
422 known: &KnownVariants,
423 arg_hint: Option<Type>,
424 arg: &Expr,
425) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
426 match (arg_hint, arg) {
427 (Some(arg_hint), Expr::RecordUpdate(_, base, updates)) => {
428 infer_record_update_typed_with_hint(
429 unifier,
430 supply,
431 env,
432 adts,
433 known,
434 base.as_ref(),
435 updates,
436 &arg_hint,
437 )
438 }
439 (Some(arg_hint), Expr::Dict(_, kvs))
440 if matches!(arg_hint.as_ref(), TypeKind::Record(..)) =>
441 {
442 let TypeKind::Record(fields) = arg_hint.as_ref() else {
443 unreachable!("guarded by matches!")
444 };
445 let mut preds = Vec::new();
446 let mut typed_kvs = BTreeMap::new();
447 let expected: BTreeMap<_, _> =
448 fields.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
449 for (k, v) in kvs {
450 let expected_ty = expected
451 .get(k)
452 .ok_or_else(|| TypeError::UnknownField {
453 field: k.clone(),
454 typ: Type::record(fields.clone()).to_string(),
455 })?
456 .clone();
457 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
458 unifier.unify(&t1, &expected_ty)?;
459 preds.extend(p1);
460 typed_kvs.insert(k.clone(), Arc::new(typed_v));
461 }
462 for key in expected.keys() {
463 if !typed_kvs.contains_key(key.as_ref()) {
464 return Err(TypeError::UnknownField {
465 field: key.clone(),
466 typ: Type::record(fields.clone()).to_string(),
467 });
468 }
469 }
470 let record_ty = Type::record(
471 fields
472 .iter()
473 .map(|(k, v)| (k.clone(), unifier.apply_type(v)))
474 .collect(),
475 );
476 let typed = TypedExpr::new(record_ty.clone(), TypedExprKind::Dict(typed_kvs));
477 Ok((preds, record_ty, typed))
478 }
479 _ => infer_expr(unifier, supply, env, adts, known, arg),
480 }
481}
482
483struct TypedAppState {
484 preds: Vec<Predicate>,
485 func_ty: Type,
486 typed: TypedExpr,
487 overload_name: Option<Symbol>,
488 overload_candidates: Option<Vec<Type>>,
489}
490
491struct TailAppFrame<'a> {
492 head: &'a Expr,
493 prefix_args: Vec<&'a Expr>,
494}
495
496fn app_arg_hint(unifier: &mut Unifier, func_ty: &Type) -> Option<Type> {
497 match unifier.apply_type(func_ty).as_ref() {
498 TypeKind::Fun(arg, _) => Some(arg.clone()),
499 _ => None,
500 }
501}
502
503#[allow(clippy::too_many_arguments)]
504fn infer_typed_app_head(
505 unifier: &mut Unifier,
506 supply: &mut TypeVarSupply,
507 env: &TypeEnv,
508 adts: &BTreeMap<Symbol, AdtDecl>,
509 known: &KnownVariants,
510 head: &Expr,
511) -> Result<TypedAppState, TypeError> {
512 let (preds, func_ty, typed) = infer_expr(unifier, supply, env, adts, known, head)?;
513 let mut overload_name = None;
514 let overload_candidates = match typed.kind.as_ref() {
515 TypedExprKind::Var { name, overloads } if !overloads.is_empty() => {
516 overload_name = Some(name.clone());
517 Some(overloads.clone())
518 }
519 _ => None,
520 };
521 Ok(TypedAppState {
522 preds,
523 func_ty,
524 typed,
525 overload_name,
526 overload_candidates,
527 })
528}
529
530fn apply_typed_app_arg(
531 unifier: &mut Unifier,
532 supply: &mut TypeVarSupply,
533 state: &mut TypedAppState,
534 expected_arg: Option<Type>,
535 p_arg: Vec<Predicate>,
536 arg_ty: Type,
537 typed_arg: TypedExpr,
538) -> Result<(), TypeError> {
539 let mut arg_ty = unifier.apply_type(&arg_ty);
540 let mut typed_arg = typed_arg;
541
542 if let Some(expected_arg) = expected_arg {
543 let expected_arg = unifier.apply_type(&expected_arg);
544 if let (Some(expected_elem), Some(arg_elem)) = (
545 unary_app_arg(&expected_arg, "Array"),
546 unary_app_arg(&arg_ty, "List"),
547 ) {
548 unifier.unify(&expected_elem, &arg_elem)?;
549 let elem_ty = unifier.apply_type(&expected_elem);
550 let list_ty = Type::list(elem_ty.clone());
551 let array_ty = Type::array(elem_ty);
552 let coercion_ty = Type::fun(list_ty, array_ty.clone());
553 let coercion_fn = TypedExpr::new(
554 coercion_ty,
555 TypedExprKind::Var {
556 name: Symbol::intern("prim_array_from_list"),
557 overloads: vec![],
558 },
559 );
560 typed_arg = TypedExpr::new(
561 array_ty.clone(),
562 TypedExprKind::App(Arc::new(coercion_fn), Arc::new(typed_arg)),
563 );
564 arg_ty = array_ty;
565 }
566 }
567 if let Some(candidates) = state.overload_candidates.take() {
568 let candidates = candidates
569 .into_iter()
570 .map(|t| unifier.apply_type(&t))
571 .collect::<Vec<_>>();
572 let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
573 if narrowed.is_empty()
574 && let Some(name) = &state.overload_name
575 {
576 return Err(TypeError::AmbiguousOverload(name.clone()));
577 }
578 state.overload_candidates = Some(narrowed);
579 }
580 let res_ty = match state.overload_candidates.as_ref() {
581 Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
582 _ => Type::var(supply.fresh(Some(Symbol::intern("r")))),
583 };
584 unifier.unify(&state.func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
585 let result_ty = match state.overload_candidates.as_ref() {
586 Some(candidates) if candidates.len() == 1 => unifier.apply_type(&candidates[0]),
587 _ => unifier.apply_type(&res_ty),
588 };
589 state.preds.extend(p_arg);
590 state.typed = TypedExpr::new(
591 result_ty.clone(),
592 TypedExprKind::App(Arc::new(state.typed.clone()), Arc::new(typed_arg)),
593 );
594 state.func_ty = result_ty;
595 Ok(())
596}
597
598#[allow(clippy::too_many_arguments)]
599fn infer_typed_app_expr_arg(
600 unifier: &mut Unifier,
601 supply: &mut TypeVarSupply,
602 env: &TypeEnv,
603 adts: &BTreeMap<Symbol, AdtDecl>,
604 known: &KnownVariants,
605 state: &mut TypedAppState,
606 arg: &Expr,
607) -> Result<(), TypeError> {
608 let expected_arg = app_arg_hint(unifier, &state.func_ty);
609 let (p_arg, arg_ty, typed_arg) =
610 infer_app_arg_typed(unifier, supply, env, adts, known, expected_arg.clone(), arg)?;
611 apply_typed_app_arg(
612 unifier,
613 supply,
614 state,
615 expected_arg,
616 p_arg,
617 arg_ty,
618 typed_arg,
619 )
620}
621
622fn collect_tail_app_chain(expr: &Expr) -> Option<(&Expr, Vec<TailAppFrame<'_>>)> {
623 let mut frames = Vec::new();
624 let mut cur = expr;
625 while let Expr::App(..) = cur {
626 let (head, mut args) = collect_app_chain(cur);
627 let Some(tail) = args.pop() else {
628 break;
629 };
630 if !matches!(tail, Expr::App(..)) {
631 break;
632 }
633 frames.push(TailAppFrame {
634 head,
635 prefix_args: args,
636 });
637 cur = tail;
638 }
639 (!frames.is_empty()).then_some((cur, frames))
640}
641
642#[allow(clippy::too_many_arguments)]
643fn infer_tail_app_chain_typed(
644 unifier: &mut Unifier,
645 supply: &mut TypeVarSupply,
646 env: &TypeEnv,
647 adts: &BTreeMap<Symbol, AdtDecl>,
648 known: &KnownVariants,
649 leaf: &Expr,
650 frames: Vec<TailAppFrame<'_>>,
651) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
652 let (mut preds, mut tail_ty, mut typed_tail) =
653 infer_expr(unifier, supply, env, adts, known, leaf)?;
654
655 for frame in frames.into_iter().rev() {
656 let mut state = infer_typed_app_head(unifier, supply, env, adts, known, frame.head)?;
657 for arg in frame.prefix_args {
658 infer_typed_app_expr_arg(unifier, supply, env, adts, known, &mut state, arg)?;
659 }
660 let expected_arg = app_arg_hint(unifier, &state.func_ty);
661 apply_typed_app_arg(
662 unifier,
663 supply,
664 &mut state,
665 expected_arg,
666 Vec::new(),
667 tail_ty,
668 typed_tail,
669 )?;
670 preds.extend(state.preds);
671 tail_ty = state.func_ty;
672 typed_tail = state.typed;
673 }
674
675 Ok((preds, tail_ty, typed_tail))
676}
677
678#[allow(clippy::too_many_arguments)]
679fn infer_record_update_type_with_hint(
680 unifier: &mut Unifier,
681 supply: &mut TypeVarSupply,
682 env: &TypeEnv,
683 adts: &BTreeMap<Symbol, AdtDecl>,
684 known: &KnownVariants,
685 base: &Expr,
686 updates: &BTreeMap<Symbol, Arc<Expr>>,
687 hint_ty: &Type,
688) -> Result<(Vec<Predicate>, Type), TypeError> {
689 let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
690 unifier.unify(&t_base, hint_ty)?;
691 let base_ty = unifier.apply_type(&t_base);
692 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
693 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
694 let (result_ty, fields) = resolve_record_update(
695 unifier,
696 supply,
697 adts,
698 &base_ty,
699 known_variant,
700 &update_fields,
701 )?;
702 let expected: BTreeMap<_, _> = fields.into_iter().collect();
703
704 let mut preds = p_base;
705 for (k, v) in updates {
706 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
707 field: k.clone(),
708 typ: result_ty.to_string(),
709 })?;
710 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
711 unifier.unify(&t1, expected_ty)?;
712 preds.extend(p1);
713 }
714 Ok((preds, result_ty))
715}
716
717#[allow(clippy::too_many_arguments)]
718fn infer_record_update_typed_with_hint(
719 unifier: &mut Unifier,
720 supply: &mut TypeVarSupply,
721 env: &TypeEnv,
722 adts: &BTreeMap<Symbol, AdtDecl>,
723 known: &KnownVariants,
724 base: &Expr,
725 updates: &BTreeMap<Symbol, Arc<Expr>>,
726 hint_ty: &Type,
727) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
728 let (p_base, t_base, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
729 unifier.unify(&t_base, hint_ty)?;
730 let base_ty = unifier.apply_type(&t_base);
731 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
732 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
733 let (result_ty, fields) = resolve_record_update(
734 unifier,
735 supply,
736 adts,
737 &base_ty,
738 known_variant,
739 &update_fields,
740 )?;
741 let expected: BTreeMap<_, _> = fields.into_iter().collect();
742
743 let mut preds = p_base;
744 let mut typed_updates = BTreeMap::new();
745 for (k, v) in updates {
746 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
747 field: k.clone(),
748 typ: result_ty.to_string(),
749 })?;
750 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
751 unifier.unify(&t1, expected_ty)?;
752 preds.extend(p1);
753 typed_updates.insert(k.clone(), Arc::new(typed_v));
754 }
755
756 let typed = TypedExpr::new(
757 result_ty.clone(),
758 TypedExprKind::RecordUpdate {
759 base: Arc::new(typed_base),
760 updates: typed_updates,
761 },
762 );
763 Ok((preds, result_ty, typed))
764}
765
766fn infer_expr_type(
767 unifier: &mut Unifier,
768 supply: &mut TypeVarSupply,
769 env: &TypeEnv,
770 adts: &BTreeMap<Symbol, AdtDecl>,
771 known: &KnownVariants,
772 expr: &Expr,
773) -> Result<(Vec<Predicate>, Type), TypeError> {
774 let span = *expr.span();
775 let res = unifier.with_infer_depth(span, |unifier| {
776 infer_expr_type_inner(unifier, supply, env, adts, known, expr)
777 });
778 res.map_err(|err| err.with_span(&span))
779}
780
781fn infer_expr_type_inner(
782 unifier: &mut Unifier,
783 supply: &mut TypeVarSupply,
784 env: &TypeEnv,
785 adts: &BTreeMap<Symbol, AdtDecl>,
786 known: &KnownVariants,
787 expr: &Expr,
788) -> Result<(Vec<Predicate>, Type), TypeError> {
789 match expr {
790 Expr::Bool(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Bool))),
791 Expr::Uint(_, _) => {
792 let lit_ty = Type::var(supply.fresh(Some(Symbol::intern("n"))));
793 Ok((vec![Predicate::new("Integral", lit_ty.clone())], lit_ty))
794 }
795 Expr::Int(_, _) => {
796 let lit_ty = Type::var(supply.fresh(Some(Symbol::intern("n"))));
797 Ok((
798 vec![
799 Predicate::new("Integral", lit_ty.clone()),
800 Predicate::new("AdditiveGroup", lit_ty.clone()),
801 ],
802 lit_ty,
803 ))
804 }
805 Expr::Float(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::F32))),
806 Expr::String(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::String))),
807 Expr::Uuid(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::Uuid))),
808 Expr::DateTime(_, _) => Ok((vec![], Type::builtin(BuiltinTypeId::DateTime))),
809 Expr::Hole(_) => {
810 let t = Type::var(supply.fresh(Some(Symbol::intern("hole"))));
811 Ok((vec![], t))
812 }
813 Expr::Var(var) => {
814 let schemes = env
815 .lookup(&var.name)
816 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
817 if schemes.len() == 1 {
818 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
819 let (preds, t) = instantiate(&scheme, supply);
820 Ok((preds, t))
821 } else {
822 for scheme in schemes {
823 if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
824 return Err(TypeError::AmbiguousOverload(var.name.clone()));
825 }
826 }
827 let t = Type::var(supply.fresh(Some(var.name.clone())));
828 Ok((vec![], t))
829 }
830 }
831 Expr::Lam(..) => {
832 let (params, constraints, body) = collect_lambda_chain(expr);
833 let mut ann_vars = BTreeMap::new();
834 let mut param_tys = Vec::with_capacity(params.len());
835 for (name, ann) in ¶ms {
836 let param_ty = match ann {
837 Some(ann) => type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?,
838 None => Type::var(supply.fresh(Some(name.clone()))),
839 };
840 param_tys.push((name.clone(), param_ty));
841 }
842
843 let mut env1 = env.clone();
844 let mut known_body = known.clone();
845 for (name, param_ty) in ¶m_tys {
846 env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
847 known_body.remove(name);
848 }
849
850 let (mut preds, body_ty) =
851 infer_expr_type(unifier, supply, &env1, adts, &known_body, body)?;
852 let constraint_preds =
853 predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
854 preds.extend(constraint_preds);
855
856 let mut fun_ty = unifier.apply_type(&body_ty);
857 for (_, param_ty) in param_tys.iter().rev() {
858 fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
859 }
860 Ok((preds, fun_ty))
861 }
862 Expr::App(..) => {
863 let (head, args) = collect_app_chain(expr);
864 let (mut preds, mut func_ty) =
865 infer_expr_type(unifier, supply, env, adts, known, head)?;
866 let mut overload_name = None;
867 let mut overload_candidates = if let Expr::Var(var) = head {
868 if let Some(schemes) = env.lookup(&var.name) {
869 if schemes.len() <= 1 {
870 None
871 } else {
872 let mut candidates = Vec::new();
873 for scheme in schemes {
874 if !scheme.vars.is_empty() || !scheme.preds.is_empty() {
875 return Err(TypeError::AmbiguousOverload(var.name.clone()));
876 }
877 let scheme = apply_scheme_with_unifier(scheme, unifier);
878 let (p, typ) = instantiate(&scheme, supply);
879 if !p.is_empty() {
880 return Err(TypeError::AmbiguousOverload(var.name.clone()));
881 }
882 candidates.push(typ);
883 }
884 overload_name = Some(var.name.clone());
885 Some(candidates)
886 }
887 } else {
888 None
889 }
890 } else {
891 None
892 };
893 for arg in args {
894 let arg_hint = match unifier.apply_type(&func_ty).as_ref() {
895 TypeKind::Fun(arg, _) => Some(arg.clone()),
896 _ => None,
897 };
898 let (p_arg, arg_ty) =
899 infer_app_arg_type(unifier, supply, env, adts, known, arg_hint, arg)?;
900 let arg_ty = unifier.apply_type(&arg_ty);
901 if let Some(candidates) = overload_candidates.take() {
902 let candidates = candidates
903 .into_iter()
904 .map(|t| unifier.apply_type(&t))
905 .collect::<Vec<_>>();
906 let narrowed = narrow_overload_candidates(&candidates, &arg_ty);
907 if narrowed.is_empty()
908 && let Some(name) = &overload_name
909 {
910 return Err(TypeError::AmbiguousOverload(name.clone()));
911 }
912 overload_candidates = Some(narrowed);
913 }
914 let res_ty = match overload_candidates.as_ref() {
915 Some(candidates) if candidates.len() == 1 => candidates[0].clone(),
916 _ => Type::var(supply.fresh(Some(Symbol::intern("r")))),
917 };
918 unifier.unify(&func_ty, &Type::fun(arg_ty, res_ty.clone()))?;
919 preds.extend(p_arg);
920 func_ty = match overload_candidates.as_ref() {
921 Some(candidates) if candidates.len() == 1 => unifier.apply_type(&candidates[0]),
922 _ => unifier.apply_type(&res_ty),
923 };
924 }
925 Ok((preds, func_ty))
926 }
927 Expr::Project(_, base, field) => {
928 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, base)?;
929 let base_ty = unifier.apply_type(&t1);
930 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
931 let field_ty =
932 resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
933 Ok((p1, field_ty))
934 }
935 Expr::RecordUpdate(_, base, updates) => {
936 let (p_base, t_base) = infer_expr_type(unifier, supply, env, adts, known, base)?;
937 let base_ty = unifier.apply_type(&t_base);
938 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
939 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
940 let (result_ty, fields) = resolve_record_update(
941 unifier,
942 supply,
943 adts,
944 &base_ty,
945 known_variant,
946 &update_fields,
947 )?;
948 let expected: BTreeMap<_, _> = fields.into_iter().collect();
949
950 let mut preds = p_base;
951 for (k, v) in updates {
952 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
953 field: k.clone(),
954 typ: result_ty.to_string(),
955 })?;
956 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
957 unifier.unify(&t1, expected_ty)?;
958 preds.extend(p1);
959 }
960 Ok((preds, result_ty))
961 }
962 Expr::Let(..) => {
963 let mut bindings = Vec::new();
964 let mut cur = expr;
965 while let Expr::Let(_, v, ann, d, b) = cur {
966 bindings.push((v.clone(), ann.clone(), d.clone()));
967 cur = b.as_ref();
968 }
969
970 let mut env_cur = env.clone();
971 let mut known_cur = known.clone();
972 for (v, ann, d) in bindings {
973 let (p1, t1) = if let Some(ref ann_expr) = ann {
974 let mut ann_vars = BTreeMap::new();
975 let ann_ty =
976 type_from_annotation_expr_vars(adts, ann_expr, &mut ann_vars, supply)?;
977 match d.as_ref() {
978 Expr::RecordUpdate(_, base, updates) => infer_record_update_type_with_hint(
979 unifier,
980 supply,
981 &env_cur,
982 adts,
983 &known_cur,
984 base.as_ref(),
985 updates,
986 &ann_ty,
987 )?,
988 _ => {
989 let (p1, t1) =
990 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?;
991 unifier.unify(&t1, &ann_ty)?;
992 (p1, t1)
993 }
994 }
995 } else {
996 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, &d)?
997 };
998 let def_ty = unifier.apply_type(&t1);
999 let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
1000 monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
1001 } else {
1002 let scheme = generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
1003 reject_ambiguous_scheme(&scheme)?;
1004 scheme
1005 };
1006 env_cur.extend(v.name.clone(), scheme);
1007 if let Some(known_variant) =
1008 known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
1009 {
1010 known_cur.insert(
1011 v.name.clone(),
1012 KnownVariant {
1013 adt: known_variant.adt,
1014 variant: known_variant.variant,
1015 },
1016 );
1017 } else {
1018 known_cur.remove(&v.name);
1019 }
1020 }
1021
1022 let (p_body, t_body) =
1023 infer_expr_type(unifier, supply, &env_cur, adts, &known_cur, cur)?;
1024 Ok((p_body, t_body))
1025 }
1026 Expr::LetRec(_, bindings, body) => {
1027 let mut env_seed = env.clone();
1028 let mut known_seed = known.clone();
1029 let mut binding_tys = BTreeMap::new();
1030 for (var, _ann, _def) in bindings {
1031 let tv = Type::var(supply.fresh(Some(var.name.clone())));
1032 env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
1033 known_seed.remove(&var.name);
1034 binding_tys.insert(var.name.clone(), tv);
1035 }
1036
1037 let mut inferred = Vec::with_capacity(bindings.len());
1038 for (var, ann, def) in bindings {
1039 let (preds, def_ty) =
1040 infer_expr_type(unifier, supply, &env_seed, adts, &known_seed, def)?;
1041 if let Some(ann) = ann {
1042 let mut ann_vars = BTreeMap::new();
1043 let ann_ty = type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
1044 unifier.unify(&def_ty, &ann_ty)?;
1045 }
1046 let binding_ty = binding_tys
1047 .get(&var.name)
1048 .cloned()
1049 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1050 unifier.unify(&binding_ty, &def_ty)?;
1051 let resolved_ty = unifier.apply_type(&binding_ty);
1052
1053 if let Some(known_variant) =
1054 known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
1055 {
1056 known_seed.insert(
1057 var.name.clone(),
1058 KnownVariant {
1059 adt: known_variant.adt,
1060 variant: known_variant.variant,
1061 },
1062 );
1063 } else {
1064 known_seed.remove(&var.name);
1065 }
1066 inferred.push((var.name.clone(), preds, resolved_ty));
1067 }
1068
1069 let mut env_body = env.clone();
1070 for (name, preds, def_ty) in inferred {
1071 let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
1072 reject_ambiguous_scheme(&scheme)?;
1073 env_body.extend(name, scheme);
1074 }
1075
1076 let (p_body, t_body) =
1077 infer_expr_type(unifier, supply, &env_body, adts, &known_seed, body)?;
1078 Ok((p_body, t_body))
1079 }
1080 Expr::Ite(_, cond, then_expr, else_expr) => {
1081 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, cond)?;
1082 unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
1083 let (p2, t2) = infer_expr_type(unifier, supply, env, adts, known, then_expr)?;
1084 let (p3, t3) = infer_expr_type(unifier, supply, env, adts, known, else_expr)?;
1085 unifier.unify(&t2, &t3)?;
1086 let out_ty = unifier.apply_type(&t2);
1087 let mut preds = p1;
1088 preds.extend(p2);
1089 preds.extend(p3);
1090 Ok((preds, out_ty))
1091 }
1092 Expr::Tuple(_, elems) => {
1093 let mut preds = Vec::new();
1094 let mut types = Vec::new();
1095 for elem in elems {
1096 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
1097 preds.extend(p1);
1098 types.push(unifier.apply_type(&t1));
1099 }
1100 let tuple_ty = Type::tuple(types);
1101 Ok((preds, tuple_ty))
1102 }
1103 Expr::List(_, elems) => {
1104 let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("a"))));
1105 let mut preds = Vec::new();
1106 for elem in elems {
1107 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, elem.as_ref())?;
1108 unifier.unify(&t1, &elem_tv)?;
1109 preds.extend(p1);
1110 }
1111 let list_ty = Type::app(
1112 Type::builtin(BuiltinTypeId::List),
1113 unifier.apply_type(&elem_tv),
1114 );
1115 Ok((preds, list_ty))
1116 }
1117 Expr::Dict(_, kvs) => {
1118 let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("v"))));
1119 let mut preds = Vec::new();
1120 for v in kvs.values() {
1121 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, v.as_ref())?;
1122 unifier.unify(&t1, &elem_tv)?;
1123 preds.extend(p1);
1124 }
1125 let dict_ty = Type::app(
1126 Type::builtin(BuiltinTypeId::Dict),
1127 unifier.apply_type(&elem_tv),
1128 );
1129 Ok((preds, dict_ty))
1130 }
1131 Expr::Match(_, scrutinee, arms) => {
1132 let (p1, t1) = infer_expr_type(unifier, supply, env, adts, known, scrutinee.as_ref())?;
1133 let mut preds = p1;
1134 let res_ty = Type::var(supply.fresh(Some(Symbol::intern("match"))));
1135 let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
1136
1137 for (pat, expr) in arms {
1138 let scrutinee_ty = unifier.apply_type(&t1);
1139 let (p_pat, binds) = infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
1140 preds.extend(p_pat);
1141
1142 let mut env_arm = env.clone();
1143 for (name, ty) in binds {
1144 env_arm.extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
1145 }
1146 let mut known_arm = known.clone();
1147 if let Expr::Var(var) = scrutinee.as_ref() {
1148 match pat {
1149 Pattern::Named(_, name, _) => {
1150 let name_sym = name.to_dotted_symbol();
1151 if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
1152 known_arm.insert(
1153 var.name.clone(),
1154 KnownVariant {
1155 adt: adt.name.clone(),
1156 variant: name_sym,
1157 },
1158 );
1159 } else {
1160 known_arm.remove(&var.name);
1161 }
1162 }
1163 _ => {
1164 known_arm.remove(&var.name);
1165 }
1166 }
1167 }
1168 let (p_expr, t_expr) =
1169 infer_expr_type(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1170 unifier.unify(&res_ty, &t_expr)?;
1171 preds.extend(p_expr);
1172 }
1173
1174 let scrutinee_ty = unifier.apply_type(&t1);
1175 check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1176 let out_ty = unifier.apply_type(&res_ty);
1177 Ok((preds, out_ty))
1178 }
1179 Expr::Ann(_, expr, ann) => {
1180 let ann_ty = type_from_annotation_expr(adts, ann)?;
1181 match expr.as_ref() {
1182 Expr::RecordUpdate(_, base, updates) => {
1183 let (preds, out_ty) = infer_record_update_type_with_hint(
1184 unifier,
1185 supply,
1186 env,
1187 adts,
1188 known,
1189 base.as_ref(),
1190 updates,
1191 &ann_ty,
1192 )?;
1193 Ok((preds, out_ty))
1194 }
1195 _ => {
1196 let (preds, expr_ty) =
1197 infer_expr_type(unifier, supply, env, adts, known, expr)?;
1198 unifier.unify(&expr_ty, &ann_ty)?;
1199 let out_ty = unifier.apply_type(&ann_ty);
1200 Ok((preds, out_ty))
1201 }
1202 }
1203 }
1204 }
1205}
1206
1207fn infer_expr(
1208 unifier: &mut Unifier,
1209 supply: &mut TypeVarSupply,
1210 env: &TypeEnv,
1211 adts: &BTreeMap<Symbol, AdtDecl>,
1212 known: &KnownVariants,
1213 expr: &Expr,
1214) -> Result<(Vec<Predicate>, Type, TypedExpr), TypeError> {
1215 let span = *expr.span();
1216 let res = unifier.with_infer_depth(span, |unifier| {
1217 (|| match expr {
1218 Expr::Bool(_, v) => {
1219 let t = Type::builtin(BuiltinTypeId::Bool);
1220 Ok((
1221 vec![],
1222 t.clone(),
1223 TypedExpr::new(t, TypedExprKind::Bool(*v)),
1224 ))
1225 }
1226 Expr::Uint(_, v) => {
1227 let t = Type::var(supply.fresh(Some(Symbol::intern("n"))));
1228 Ok((
1229 vec![Predicate::new("Integral", t.clone())],
1230 t.clone(),
1231 TypedExpr::new(t, TypedExprKind::Uint(*v)),
1232 ))
1233 }
1234 Expr::Int(_, v) => {
1235 let t = Type::var(supply.fresh(Some(Symbol::intern("n"))));
1236 Ok((
1237 vec![
1238 Predicate::new("Integral", t.clone()),
1239 Predicate::new("AdditiveGroup", t.clone()),
1240 ],
1241 t.clone(),
1242 TypedExpr::new(t, TypedExprKind::Int(*v)),
1243 ))
1244 }
1245 Expr::Float(_, v) => {
1246 let t = Type::builtin(BuiltinTypeId::F32);
1247 Ok((
1248 vec![],
1249 t.clone(),
1250 TypedExpr::new(t, TypedExprKind::Float(*v)),
1251 ))
1252 }
1253 Expr::String(_, v) => {
1254 let t = Type::builtin(BuiltinTypeId::String);
1255 Ok((
1256 vec![],
1257 t.clone(),
1258 TypedExpr::new(t, TypedExprKind::String(v.clone())),
1259 ))
1260 }
1261 Expr::Uuid(_, v) => {
1262 let t = Type::builtin(BuiltinTypeId::Uuid);
1263 Ok((
1264 vec![],
1265 t.clone(),
1266 TypedExpr::new(t, TypedExprKind::Uuid(*v)),
1267 ))
1268 }
1269 Expr::DateTime(_, v) => {
1270 let t = Type::builtin(BuiltinTypeId::DateTime);
1271 Ok((
1272 vec![],
1273 t.clone(),
1274 TypedExpr::new(t, TypedExprKind::DateTime(*v)),
1275 ))
1276 }
1277 Expr::Hole(_) => {
1278 let t = Type::var(supply.fresh(Some(Symbol::intern("hole"))));
1279 Ok((vec![], t.clone(), TypedExpr::new(t, TypedExprKind::Hole)))
1280 }
1281 Expr::Var(var) => {
1282 let schemes = env
1283 .lookup(&var.name)
1284 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1285 if schemes.len() == 1 {
1286 let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
1287 let (preds, t) = instantiate(&scheme, supply);
1288 let typed = TypedExpr::new(
1289 t.clone(),
1290 TypedExprKind::Var {
1291 name: var.name.clone(),
1292 overloads: vec![],
1293 },
1294 );
1295 Ok((preds, t, typed))
1296 } else {
1297 let mut overloads = Vec::new();
1298 for scheme in schemes {
1299 if !scheme.preds.is_empty() {
1300 return Err(TypeError::AmbiguousOverload(var.name.clone()));
1301 }
1302
1303 let scheme = apply_scheme_with_unifier(scheme, unifier);
1304 let (preds, typ) = instantiate(&scheme, supply);
1305 if !preds.is_empty() {
1306 return Err(TypeError::AmbiguousOverload(var.name.clone()));
1307 }
1308 overloads.push(typ);
1309 }
1310 let t = Type::var(supply.fresh(Some(var.name.clone())));
1311 let typed = TypedExpr::new(
1312 t.clone(),
1313 TypedExprKind::Var {
1314 name: var.name.clone(),
1315 overloads,
1316 },
1317 );
1318 Ok((vec![], t, typed))
1319 }
1320 }
1321 Expr::Lam(..) => {
1322 let (params, constraints, body) = collect_lambda_chain(expr);
1323 let mut ann_vars = BTreeMap::new();
1324 let mut param_tys = Vec::with_capacity(params.len());
1325 for (name, ann) in ¶ms {
1326 let param_ty = match ann {
1327 Some(ann) => {
1328 type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?
1329 }
1330 None => Type::var(supply.fresh(Some(name.clone()))),
1331 };
1332 param_tys.push((name.clone(), param_ty));
1333 }
1334
1335 let mut env1 = env.clone();
1336 let mut known_body = known.clone();
1337 for (name, param_ty) in ¶m_tys {
1338 env1.extend(name.clone(), Scheme::new(vec![], vec![], param_ty.clone()));
1339 known_body.remove(name);
1340 }
1341
1342 let (mut preds, body_ty, typed_body) =
1343 infer_expr(unifier, supply, &env1, adts, &known_body, body)?;
1344 let constraint_preds =
1345 predicates_from_constraints(adts, &constraints, &mut ann_vars, supply)?;
1346 preds.extend(constraint_preds);
1347
1348 let mut typed = typed_body;
1349 let mut fun_ty = unifier.apply_type(&body_ty);
1350 for (name, param_ty) in param_tys.iter().rev() {
1351 fun_ty = Type::fun(unifier.apply_type(param_ty), fun_ty);
1352 typed = TypedExpr::new(
1353 fun_ty.clone(),
1354 TypedExprKind::Lam {
1355 param: name.clone(),
1356 body: Arc::new(typed),
1357 },
1358 );
1359 }
1360
1361 Ok((preds, fun_ty, typed))
1362 }
1363 Expr::App(..) => {
1364 if let Some((leaf, frames)) = collect_tail_app_chain(expr) {
1365 return infer_tail_app_chain_typed(
1366 unifier, supply, env, adts, known, leaf, frames,
1367 );
1368 }
1369 let (head, args) = collect_app_chain(expr);
1370 let mut state = infer_typed_app_head(unifier, supply, env, adts, known, head)?;
1371 for arg in args {
1372 infer_typed_app_expr_arg(unifier, supply, env, adts, known, &mut state, arg)?;
1373 }
1374 Ok((state.preds, state.func_ty, state.typed))
1375 }
1376 Expr::Project(_, base, field) => {
1377 let (p1, t1, typed_base) = infer_expr(unifier, supply, env, adts, known, base)?;
1378 let base_ty = unifier.apply_type(&t1);
1379 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
1380 let field_ty =
1381 resolve_projection(unifier, supply, adts, &base_ty, known_variant, field)?;
1382 let typed = TypedExpr::new(
1383 field_ty.clone(),
1384 TypedExprKind::Project {
1385 expr: Arc::new(typed_base),
1386 field: field.clone(),
1387 },
1388 );
1389 Ok((p1, field_ty, typed))
1390 }
1391 Expr::RecordUpdate(_, base, updates) => {
1392 let (p_base, t_base, typed_base) =
1393 infer_expr(unifier, supply, env, adts, known, base)?;
1394 let base_ty = unifier.apply_type(&t_base);
1395 let known_variant = known_variant_from_expr_with_known(base, &base_ty, adts, known);
1396 let update_fields: Vec<Symbol> = updates.keys().cloned().collect();
1397 let (result_ty, fields) = resolve_record_update(
1398 unifier,
1399 supply,
1400 adts,
1401 &base_ty,
1402 known_variant,
1403 &update_fields,
1404 )?;
1405 let expected: BTreeMap<_, _> = fields.into_iter().collect();
1406
1407 let mut preds = p_base;
1408 let mut typed_updates = BTreeMap::new();
1409 for (k, v) in updates {
1410 let expected_ty = expected.get(k).ok_or_else(|| TypeError::UnknownField {
1411 field: k.clone(),
1412 typ: result_ty.to_string(),
1413 })?;
1414 let (p1, t1, typed_v) =
1415 infer_expr(unifier, supply, env, adts, known, v.as_ref())?;
1416 unifier.unify(&t1, expected_ty)?;
1417 preds.extend(p1);
1418 typed_updates.insert(k.clone(), Arc::new(typed_v));
1419 }
1420 let typed = TypedExpr::new(
1421 result_ty.clone(),
1422 TypedExprKind::RecordUpdate {
1423 base: Arc::new(typed_base),
1424 updates: typed_updates,
1425 },
1426 );
1427 Ok((preds, result_ty, typed))
1428 }
1429 Expr::Let(..) => {
1430 let mut bindings = Vec::new();
1431 let mut cur = expr;
1432 while let Expr::Let(_, v, ann, d, b) = cur {
1433 bindings.push((v.clone(), ann.clone(), d.clone()));
1434 cur = b.as_ref();
1435 }
1436
1437 let mut env_cur = env.clone();
1438 let mut known_cur = known.clone();
1439 let mut typed_defs = Vec::new();
1440 for (v, ann, d) in bindings {
1441 let (p1, t1, typed_def) = if let Some(ref ann_expr) = ann {
1442 let mut ann_vars = BTreeMap::new();
1443 let ann_ty =
1444 type_from_annotation_expr_vars(adts, ann_expr, &mut ann_vars, supply)?;
1445 match d.as_ref() {
1446 Expr::RecordUpdate(_, base, updates) => {
1447 infer_record_update_typed_with_hint(
1448 unifier,
1449 supply,
1450 &env_cur,
1451 adts,
1452 &known_cur,
1453 base.as_ref(),
1454 updates,
1455 &ann_ty,
1456 )?
1457 }
1458 _ => {
1459 let (p1, t1, typed_def) =
1460 infer_expr(unifier, supply, &env_cur, adts, &known_cur, &d)?;
1461 unifier.unify(&t1, &ann_ty)?;
1462 (p1, t1, typed_def)
1463 }
1464 }
1465 } else {
1466 infer_expr(unifier, supply, &env_cur, adts, &known_cur, &d)?
1467 };
1468 let def_ty = unifier.apply_type(&t1);
1469 let scheme = if ann.is_none() && is_integral_literal_expr(&d) {
1470 monomorphic_scheme_with_unifier(p1, def_ty.clone(), unifier)
1471 } else {
1472 let scheme = generalize_with_unifier(&env_cur, p1, def_ty.clone(), unifier);
1473 reject_ambiguous_scheme(&scheme)?;
1474 scheme
1475 };
1476 env_cur.extend(v.name.clone(), scheme);
1477 if let Some(known_variant) =
1478 known_variant_from_expr_with_known(&d, &def_ty, adts, &known_cur)
1479 {
1480 known_cur.insert(
1481 v.name.clone(),
1482 KnownVariant {
1483 adt: known_variant.adt,
1484 variant: known_variant.variant,
1485 },
1486 );
1487 } else {
1488 known_cur.remove(&v.name);
1489 }
1490 typed_defs.push((v.name.clone(), typed_def));
1491 }
1492
1493 let (p_body, t_body, typed_body) =
1494 infer_expr(unifier, supply, &env_cur, adts, &known_cur, cur)?;
1495
1496 let mut typed = typed_body;
1497 for (name, def) in typed_defs.into_iter().rev() {
1498 typed = TypedExpr::new(
1499 t_body.clone(),
1500 TypedExprKind::Let {
1501 name,
1502 def: Arc::new(def),
1503 body: Arc::new(typed),
1504 },
1505 );
1506 }
1507 Ok((p_body, t_body, typed))
1508 }
1509 Expr::LetRec(_, bindings, body) => {
1510 let mut env_seed = env.clone();
1511 let mut known_seed = known.clone();
1512 let mut binding_tys = BTreeMap::new();
1513 for (var, _ann, _def) in bindings {
1514 let tv = Type::var(supply.fresh(Some(var.name.clone())));
1515 env_seed.extend(var.name.clone(), Scheme::new(vec![], vec![], tv.clone()));
1516 known_seed.remove(&var.name);
1517 binding_tys.insert(var.name.clone(), tv);
1518 }
1519
1520 let mut inferred_defs = Vec::with_capacity(bindings.len());
1521 for (var, ann, def) in bindings {
1522 let (preds, def_ty, typed_def) =
1523 infer_expr(unifier, supply, &env_seed, adts, &known_seed, def)?;
1524 if let Some(ann) = ann {
1525 let mut ann_vars = BTreeMap::new();
1526 let ann_ty =
1527 type_from_annotation_expr_vars(adts, ann, &mut ann_vars, supply)?;
1528 unifier.unify(&def_ty, &ann_ty)?;
1529 }
1530 let binding_ty = binding_tys
1531 .get(&var.name)
1532 .cloned()
1533 .ok_or_else(|| TypeError::UnknownVar(var.name.clone()))?;
1534 unifier.unify(&binding_ty, &def_ty)?;
1535 let resolved_ty = unifier.apply_type(&binding_ty);
1536
1537 if let Some(known_variant) =
1538 known_variant_from_expr_with_known(def, &resolved_ty, adts, &known_seed)
1539 {
1540 known_seed.insert(
1541 var.name.clone(),
1542 KnownVariant {
1543 adt: known_variant.adt,
1544 variant: known_variant.variant,
1545 },
1546 );
1547 } else {
1548 known_seed.remove(&var.name);
1549 }
1550 inferred_defs.push((var.name.clone(), preds, resolved_ty, typed_def));
1551 }
1552
1553 let mut env_body = env.clone();
1554 let mut typed_bindings = Vec::with_capacity(inferred_defs.len());
1555 for (name, preds, def_ty, typed_def) in inferred_defs {
1556 let scheme = generalize_with_unifier(&env_body, preds, def_ty, unifier);
1557 reject_ambiguous_scheme(&scheme)?;
1558 env_body.extend(name.clone(), scheme);
1559 typed_bindings.push((name, Arc::new(typed_def)));
1560 }
1561
1562 let (p_body, t_body, typed_body) =
1563 infer_expr(unifier, supply, &env_body, adts, &known_seed, body)?;
1564 let typed = TypedExpr::new(
1565 t_body.clone(),
1566 TypedExprKind::LetRec {
1567 bindings: typed_bindings,
1568 body: Arc::new(typed_body),
1569 },
1570 );
1571 Ok((p_body, t_body, typed))
1572 }
1573 Expr::Ite(_, cond, then_expr, else_expr) => {
1574 let (p1, t1, typed_cond) = infer_expr(unifier, supply, env, adts, known, cond)?;
1575 unifier.unify(&t1, &Type::builtin(BuiltinTypeId::Bool))?;
1576 let (p2, t2, typed_then) =
1577 infer_expr(unifier, supply, env, adts, known, then_expr)?;
1578 let (p3, t3, typed_else) =
1579 infer_expr(unifier, supply, env, adts, known, else_expr)?;
1580 unifier.unify(&t2, &t3)?;
1581 let out_ty = unifier.apply_type(&t2);
1582 let mut preds = p1;
1583 preds.extend(p2);
1584 preds.extend(p3);
1585 let typed = TypedExpr::new(
1586 out_ty.clone(),
1587 TypedExprKind::Ite {
1588 cond: Arc::new(typed_cond),
1589 then_expr: Arc::new(typed_then),
1590 else_expr: Arc::new(typed_else),
1591 },
1592 );
1593 Ok((preds, out_ty, typed))
1594 }
1595 Expr::Tuple(_, elems) => {
1596 let mut preds = Vec::new();
1597 let mut types = Vec::new();
1598 let mut typed_elems = Vec::new();
1599 for elem in elems {
1600 let (p1, t1, typed_elem) = infer_expr(unifier, supply, env, adts, known, elem)?;
1601 preds.extend(p1);
1602 types.push(unifier.apply_type(&t1));
1603 typed_elems.push(Arc::new(typed_elem));
1604 }
1605 let tuple_ty = Type::tuple(types);
1606 let typed = TypedExpr::new(tuple_ty.clone(), TypedExprKind::Tuple(typed_elems));
1607 Ok((preds, tuple_ty, typed))
1608 }
1609 Expr::List(_, elems) => {
1610 let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("a"))));
1611 let mut preds = Vec::new();
1612 let mut typed_elems = Vec::new();
1613 for elem in elems {
1614 let (p1, t1, typed_elem) = infer_expr(unifier, supply, env, adts, known, elem)?;
1615 unifier.unify(&t1, &elem_tv)?;
1616 preds.extend(p1);
1617 typed_elems.push(Arc::new(typed_elem));
1618 }
1619 let list_ty = Type::app(
1620 Type::builtin(BuiltinTypeId::List),
1621 unifier.apply_type(&elem_tv),
1622 );
1623 let typed = TypedExpr::new(list_ty.clone(), TypedExprKind::List(typed_elems));
1624 Ok((preds, list_ty, typed))
1625 }
1626 Expr::Dict(_, kvs) => {
1627 let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("v"))));
1628 let mut preds = Vec::new();
1629 let mut typed_kvs = BTreeMap::new();
1630 for (k, v) in kvs {
1631 let (p1, t1, typed_v) = infer_expr(unifier, supply, env, adts, known, v)?;
1632 unifier.unify(&t1, &elem_tv)?;
1633 preds.extend(p1);
1634 typed_kvs.insert(k.clone(), Arc::new(typed_v));
1635 }
1636 let dict_ty = Type::app(
1637 Type::builtin(BuiltinTypeId::Dict),
1638 unifier.apply_type(&elem_tv),
1639 );
1640 let typed = TypedExpr::new(dict_ty.clone(), TypedExprKind::Dict(typed_kvs));
1641 Ok((preds, dict_ty, typed))
1642 }
1643 Expr::Match(_, scrutinee, arms) => {
1644 let (p1, t1, typed_scrutinee) =
1645 infer_expr(unifier, supply, env, adts, known, scrutinee)?;
1646 let mut preds = p1;
1647 let mut typed_arms = Vec::new();
1648 let res_ty = Type::var(supply.fresh(Some(Symbol::intern("match"))));
1649 let patterns: Vec<Pattern> = arms.iter().map(|(pat, _)| pat.clone()).collect();
1650
1651 for (pat, expr) in arms {
1652 let scrutinee_ty = unifier.apply_type(&t1);
1653 let (p_pat, binds) = infer_pattern(unifier, supply, env, pat, &scrutinee_ty)?;
1654 preds.extend(p_pat);
1655
1656 let mut env_arm = env.clone();
1657 for (name, ty) in binds {
1658 env_arm.extend(name, Scheme::new(vec![], vec![], unifier.apply_type(&ty)));
1659 }
1660 let mut known_arm = known.clone();
1661 if let Expr::Var(var) = scrutinee.as_ref() {
1662 match pat {
1663 Pattern::Named(_, name, _) => {
1664 let name_sym = name.to_dotted_symbol();
1665 if let Some((adt, _variant)) = ctor_lookup(adts, &name_sym) {
1666 known_arm.insert(
1667 var.name.clone(),
1668 KnownVariant {
1669 adt: adt.name.clone(),
1670 variant: name_sym,
1671 },
1672 );
1673 } else {
1674 known_arm.remove(&var.name);
1675 }
1676 }
1677 _ => {
1678 known_arm.remove(&var.name);
1679 }
1680 }
1681 }
1682 let (p_expr, t_expr, typed_expr) =
1683 infer_expr(unifier, supply, &env_arm, adts, &known_arm, expr)?;
1684 unifier.unify(&res_ty, &t_expr)?;
1685 preds.extend(p_expr);
1686 typed_arms.push((pat.clone(), Arc::new(typed_expr)));
1687 }
1688
1689 let scrutinee_ty = unifier.apply_type(&t1);
1690 check_match_exhaustive(adts, &scrutinee_ty, &patterns)?;
1691 let out_ty = unifier.apply_type(&res_ty);
1692 let typed = TypedExpr::new(
1693 out_ty.clone(),
1694 TypedExprKind::Match {
1695 scrutinee: Arc::new(typed_scrutinee),
1696 arms: typed_arms,
1697 },
1698 );
1699 Ok((preds, out_ty, typed))
1700 }
1701 Expr::Ann(_, expr, ann) => {
1702 let ann_ty = type_from_annotation_expr(adts, ann)?;
1703 match expr.as_ref() {
1704 Expr::RecordUpdate(_, base, updates) => infer_record_update_typed_with_hint(
1705 unifier,
1706 supply,
1707 env,
1708 adts,
1709 known,
1710 base.as_ref(),
1711 updates,
1712 &ann_ty,
1713 ),
1714 _ => {
1715 let (preds, expr_ty, typed_expr) =
1716 infer_expr(unifier, supply, env, adts, known, expr)?;
1717 unifier.unify(&expr_ty, &ann_ty)?;
1718 let out_ty = unifier.apply_type(&ann_ty);
1719 Ok((preds, out_ty, typed_expr))
1720 }
1721 }
1722 }
1723 })()
1724 });
1725 res.map_err(|err| err.with_span(&span))
1726}
1727
1728fn ctor_lookup<'a>(
1729 adts: &'a BTreeMap<Symbol, AdtDecl>,
1730 name: &Symbol,
1731) -> Option<(&'a AdtDecl, &'a AdtVariant)> {
1732 let mut found = None;
1733 for adt in adts.values() {
1734 if let Some(variant) = adt.variants.iter().find(|v| &v.name == name) {
1735 if found.is_some() {
1736 return None;
1737 }
1738 found = Some((adt, variant));
1739 }
1740 }
1741 found
1742}
1743
1744fn record_fields(variant: &AdtVariant) -> Option<&[(Symbol, Type)]> {
1745 if variant.args.len() != 1 {
1746 return None;
1747 }
1748 match variant.args[0].as_ref() {
1749 TypeKind::Record(fields) => Some(fields),
1750 _ => None,
1751 }
1752}
1753
1754fn instantiate_variant_fields(
1755 adt: &AdtDecl,
1756 variant: &AdtVariant,
1757 supply: &mut TypeVarSupply,
1758) -> Option<(Type, Vec<(Symbol, Type)>)> {
1759 let fields = record_fields(variant)?;
1760 let mut subst = Subst::new_sync();
1761 for param in &adt.params {
1762 let fresh = Type::var(supply.fresh(param.var.name.clone()));
1763 subst = subst.insert(param.var.id, fresh);
1764 }
1765 let result_ty = adt.result_type().apply(&subst);
1766 let fields = fields
1767 .iter()
1768 .map(|(name, ty)| (name.clone(), ty.apply(&subst)))
1769 .collect();
1770 Some((result_ty, fields))
1771}
1772
1773fn known_variant_from_expr(
1774 expr: &Expr,
1775 expr_ty: &Type,
1776 adts: &BTreeMap<Symbol, AdtDecl>,
1777) -> Option<KnownVariant> {
1778 let mut expr = expr;
1779 while let Expr::Ann(_, inner, _) = expr {
1780 expr = inner.as_ref();
1781 }
1782 if matches!(expr_ty.as_ref(), TypeKind::Fun(..)) {
1783 return None;
1784 }
1785 let ctor = match expr {
1786 Expr::App(_, f, _) => match f.as_ref() {
1787 Expr::Var(var) => var.name.clone(),
1788 _ => return None,
1789 },
1790 _ => return None,
1791 };
1792 let (adt, variant) = ctor_lookup(adts, &ctor)?;
1793 record_fields(variant)?;
1794 Some(KnownVariant {
1795 adt: adt.name.clone(),
1796 variant: variant.name.clone(),
1797 })
1798}
1799
1800fn known_variant_from_expr_with_known(
1801 expr: &Expr,
1802 expr_ty: &Type,
1803 adts: &BTreeMap<Symbol, AdtDecl>,
1804 known: &KnownVariants,
1805) -> Option<KnownVariant> {
1806 let mut expr = expr;
1807 while let Expr::Ann(_, inner, _) = expr {
1808 expr = inner.as_ref();
1809 }
1810 match expr {
1811 Expr::Var(var) => known.get(&var.name).cloned(),
1812 Expr::RecordUpdate(_, base, _) => {
1813 known_variant_from_expr_with_known(base.as_ref(), expr_ty, adts, known)
1814 }
1815 _ => known_variant_from_expr(expr, expr_ty, adts),
1816 }
1817}
1818
1819fn select_record_variant<'a, F>(
1820 adts: &'a BTreeMap<Symbol, AdtDecl>,
1821 base_ty: &Type,
1822 known_variant: Option<KnownVariant>,
1823 field_for_errors: &Symbol,
1824 matches_fields: F,
1825) -> Result<(&'a AdtDecl, &'a AdtVariant), TypeError>
1826where
1827 F: Fn(&[(Symbol, Type)]) -> bool,
1828{
1829 if let Some(info) = known_variant {
1830 let adt = adts
1831 .get(&info.adt)
1832 .ok_or_else(|| TypeError::UnknownTypeName(info.adt.clone()))?;
1833 let variant = adt
1834 .variants
1835 .iter()
1836 .find(|v| v.name == info.variant)
1837 .ok_or_else(|| TypeError::UnknownField {
1838 field: field_for_errors.clone(),
1839 typ: base_ty.to_string(),
1840 })?;
1841 return Ok((adt, variant));
1842 }
1843
1844 if let Some(adt_name) = type_head_name(base_ty) {
1845 let adt = adts.get(adt_name).ok_or_else(|| TypeError::UnknownField {
1846 field: field_for_errors.clone(),
1847 typ: base_ty.to_string(),
1848 })?;
1849 if adt.variants.len() == 1 {
1850 return Ok((adt, &adt.variants[0]));
1851 }
1852 return Err(TypeError::FieldNotKnown {
1853 field: field_for_errors.clone(),
1854 typ: base_ty.to_string(),
1855 });
1856 }
1857
1858 if matches!(base_ty.as_ref(), TypeKind::Var(_)) {
1859 let mut candidates = Vec::new();
1860 for adt in adts.values() {
1861 if adt.variants.len() != 1 {
1862 continue;
1863 }
1864 let variant = &adt.variants[0];
1865 let Some(fields) = record_fields(variant) else {
1866 continue;
1867 };
1868 if matches_fields(fields) {
1869 candidates.push((adt, variant));
1870 }
1871 }
1872 if candidates.len() == 1 {
1873 return Ok(candidates.remove(0));
1874 }
1875 if candidates.is_empty() {
1876 return Err(TypeError::UnknownField {
1877 field: field_for_errors.clone(),
1878 typ: base_ty.to_string(),
1879 });
1880 }
1881 return Err(TypeError::FieldNotKnown {
1882 field: field_for_errors.clone(),
1883 typ: base_ty.to_string(),
1884 });
1885 }
1886
1887 Err(TypeError::UnknownField {
1888 field: field_for_errors.clone(),
1889 typ: base_ty.to_string(),
1890 })
1891}
1892
1893fn resolve_record_update(
1894 unifier: &mut Unifier,
1895 supply: &mut TypeVarSupply,
1896 adts: &BTreeMap<Symbol, AdtDecl>,
1897 base_ty: &Type,
1898 known_variant: Option<KnownVariant>,
1899 update_fields: &[Symbol],
1900) -> Result<(Type, Vec<(Symbol, Type)>), TypeError> {
1901 if let TypeKind::Record(fields) = base_ty.as_ref() {
1902 return Ok((base_ty.clone(), fields.clone()));
1903 }
1904
1905 let field_for_errors = update_fields
1906 .first()
1907 .cloned()
1908 .unwrap_or_else(|| Symbol::intern("_"));
1909
1910 let (adt, variant) =
1911 select_record_variant(adts, base_ty, known_variant, &field_for_errors, |fields| {
1912 update_fields
1913 .iter()
1914 .all(|field| fields.iter().any(|(name, _)| name == field))
1915 })?;
1916
1917 let (result_ty, fields) =
1918 instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1919 TypeError::UnknownField {
1920 field: field_for_errors.clone(),
1921 typ: base_ty.to_string(),
1922 }
1923 })?;
1924
1925 for field in update_fields {
1926 if fields.iter().all(|(name, _)| name != field) {
1927 return Err(TypeError::UnknownField {
1928 field: field.clone(),
1929 typ: base_ty.to_string(),
1930 });
1931 }
1932 }
1933
1934 unifier.unify(base_ty, &result_ty)?;
1935 let result_ty = unifier.apply_type(&result_ty);
1936 let fields = fields
1937 .into_iter()
1938 .map(|(name, ty)| (name, unifier.apply_type(&ty)))
1939 .collect();
1940 Ok((result_ty, fields))
1941}
1942
1943fn resolve_projection(
1944 unifier: &mut Unifier,
1945 supply: &mut TypeVarSupply,
1946 adts: &BTreeMap<Symbol, AdtDecl>,
1947 base_ty: &Type,
1948 known_variant: Option<KnownVariant>,
1949 field: &Symbol,
1950) -> Result<Type, TypeError> {
1951 if let Ok(index) = field.as_ref().parse::<usize>() {
1952 let elem_ty = match base_ty.as_ref() {
1953 TypeKind::Tuple(elems) => {
1954 elems
1955 .get(index)
1956 .cloned()
1957 .ok_or_else(|| TypeError::UnknownField {
1958 field: field.clone(),
1959 typ: base_ty.to_string(),
1960 })?
1961 }
1962 TypeKind::Var(_) => {
1963 let mut elems = Vec::with_capacity(index + 1);
1964 for _ in 0..=index {
1965 elems.push(Type::var(supply.fresh(Some(Symbol::intern("t")))));
1966 }
1967 let tuple_ty = Type::tuple(elems.clone());
1968 unifier.unify(base_ty, &tuple_ty)?;
1969 elems[index].clone()
1970 }
1971 _ => {
1972 return Err(TypeError::UnknownField {
1973 field: field.clone(),
1974 typ: base_ty.to_string(),
1975 });
1976 }
1977 };
1978 return Ok(unifier.apply_type(&elem_ty));
1979 }
1980
1981 let (adt, variant) = select_record_variant(adts, base_ty, known_variant, field, |fields| {
1982 fields.iter().any(|(name, _)| name == field)
1983 })?;
1984
1985 let (result_ty, fields) =
1986 instantiate_variant_fields(adt, variant, supply).ok_or_else(|| {
1987 TypeError::UnknownField {
1988 field: field.clone(),
1989 typ: base_ty.to_string(),
1990 }
1991 })?;
1992 let field_ty = fields
1993 .iter()
1994 .find(|(name, _)| name == field)
1995 .map(|(_, ty)| ty.clone())
1996 .ok_or_else(|| TypeError::UnknownField {
1997 field: field.clone(),
1998 typ: base_ty.to_string(),
1999 })?;
2000 unifier.unify(base_ty, &result_ty)?;
2001 Ok(unifier.apply_type(&field_ty))
2002}
2003
2004fn decompose_fun(typ: &Type, arity: usize) -> Option<(Vec<Type>, Type)> {
2005 let mut args = Vec::with_capacity(arity);
2006 let mut cur = typ.clone();
2007 for _ in 0..arity {
2008 match cur.as_ref() {
2009 TypeKind::Fun(a, b) => {
2010 args.push(a.clone());
2011 cur = b.clone();
2012 }
2013 _ => return None,
2014 }
2015 }
2016 Some((args, cur))
2017}
2018
/// Predicates collected from a pattern, plus the variables it binds, each
/// paired with its type.
type InferPatternResult = (Vec<Predicate>, Vec<(Symbol, Type)>);

/// Checks `pat` against `scrutinee_ty`, unifying as needed.
///
/// Returns the predicates the pattern introduces and the variables it binds.
/// Binding types have the unifier's substitution applied at the end of each
/// compound pattern. Constructor patterns look the constructor up in `env`,
/// which must hold exactly one scheme for it. Errors are tagged with the
/// span of `pat`.
fn infer_pattern(
    unifier: &mut Unifier,
    supply: &mut TypeVarSupply,
    env: &TypeEnv,
    pat: &Pattern,
    scrutinee_ty: &Type,
) -> Result<InferPatternResult, TypeError> {
    let span = *pat.span();
    let res = (|| match pat {
        // `_` matches anything and binds nothing.
        Pattern::Wildcard(..) => Ok((vec![], vec![])),
        // A variable binds the whole scrutinee.
        Pattern::Var(var) => Ok((
            vec![],
            vec![(var.name.clone(), unifier.apply_type(scrutinee_ty))],
        )),
        Pattern::Named(_, name, ps) => {
            let ctor_name = name.to_dotted_symbol();
            let schemes = env
                .lookup(&ctor_name)
                .ok_or_else(|| TypeError::UnknownVar(ctor_name.clone()))?;
            if schemes.len() != 1 {
                return Err(TypeError::AmbiguousOverload(ctor_name));
            }
            let scheme = apply_scheme_with_unifier(&schemes[0], unifier);
            let (preds, ctor_ty) = instantiate(&scheme, supply);
            // One argument type per sub-pattern. What remains must be the
            // scrutinee type. `decompose_fun` fails if the constructor type
            // has fewer arrows than there are sub-patterns.
            let (arg_tys, res_ty) = decompose_fun(&ctor_ty, ps.len())
                .ok_or(TypeError::UnsupportedExpr("pattern constructor"))?;
            unifier.unify(&res_ty, scrutinee_ty)?;
            let mut all_preds = preds;
            let mut bindings = Vec::new();
            for (p, arg_ty) in ps.iter().zip(arg_tys.iter()) {
                // Re-apply before each sub-pattern so earlier sub-patterns'
                // unifications are visible.
                let arg_ty = unifier.apply_type(arg_ty);
                let (p1, binds1) = infer_pattern(unifier, supply, env, p, &arg_ty)?;
                all_preds.extend(p1);
                bindings.extend(binds1);
            }
            // Refresh binding types with everything learned from the
            // sub-patterns.
            let bindings = bindings
                .into_iter()
                .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                .collect();
            Ok((all_preds, bindings))
        }
        // `[p1, p2, ...]`: the scrutinee must be a `List a`, and every
        // element pattern is checked against `a`.
        Pattern::List(_, ps) => {
            let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("a"))));
            let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
            unifier.unify(scrutinee_ty, &list_ty)?;
            let mut preds = Vec::new();
            let mut bindings = Vec::new();
            for p in ps {
                let elem_ty = unifier.apply_type(&elem_tv);
                let (p1, binds1) = infer_pattern(unifier, supply, env, p, &elem_ty)?;
                preds.extend(p1);
                bindings.extend(binds1);
            }
            let bindings = bindings
                .into_iter()
                .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                .collect();
            Ok((preds, bindings))
        }
        // `head : tail` (cons): `head` has the element type and `tail` has
        // the list type.
        Pattern::Cons(_, head, tail) => {
            let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("a"))));
            let list_ty = Type::app(Type::builtin(BuiltinTypeId::List), elem_tv.clone());
            unifier.unify(scrutinee_ty, &list_ty)?;
            let mut preds = Vec::new();
            let mut bindings = Vec::new();

            let head_ty = unifier.apply_type(&elem_tv);
            let (p1, binds1) = infer_pattern(unifier, supply, env, head, &head_ty)?;
            preds.extend(p1);
            bindings.extend(binds1);

            // Built after the head is checked, so the tail sees any
            // refinement of the element type.
            let tail_ty = Type::app(
                Type::builtin(BuiltinTypeId::List),
                unifier.apply_type(&elem_tv),
            );
            let (p2, binds2) = infer_pattern(unifier, supply, env, tail, &tail_ty)?;
            preds.extend(p2);
            bindings.extend(binds2);

            let bindings = bindings
                .into_iter()
                .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                .collect();
            Ok((preds, bindings))
        }
        Pattern::Tuple(_, elems) => {
            // The scrutinee must be a tuple of exactly `elems.len()` fresh
            // element types. These are applied once, right after unification.
            let mut elem_tys: Vec<Type> = (0..elems.len())
                .map(|i| Type::var(supply.fresh(Some(Symbol::intern(&format!("t{i}"))))))
                .collect();
            let expected = Type::tuple(elem_tys.clone());
            unifier.unify(scrutinee_ty, &expected)?;
            elem_tys = elem_tys
                .into_iter()
                .map(|t| unifier.apply_type(&t))
                .collect();

            let mut preds = Vec::new();
            let mut bindings = Vec::new();
            for (p, ty) in elems.iter().zip(elem_tys.iter()) {
                let (p_preds, p_binds) = infer_pattern(unifier, supply, env, p, ty)?;
                preds.extend(p_preds);
                bindings.extend(p_binds);
            }
            let bindings = bindings
                .into_iter()
                .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                .collect();
            Ok((preds, bindings))
        }
        Pattern::Dict(_, fields) => {
            if let TypeKind::Record(ty_fields) = scrutinee_ty.as_ref() {
                // Against a structural record type, each key selects that
                // field's type. An unknown key is an error.
                let mut preds = Vec::new();
                let mut bindings = Vec::new();
                for (key, pat) in fields {
                    let ty = ty_fields
                        .iter()
                        .find(|(name, _)| name == key)
                        .map(|(_, ty)| unifier.apply_type(ty))
                        .ok_or_else(|| TypeError::UnknownField {
                            field: key.clone(),
                            typ: scrutinee_ty.to_string(),
                        })?;
                    let (p_preds, p_binds) = infer_pattern(unifier, supply, env, pat, &ty)?;
                    preds.extend(p_preds);
                    bindings.extend(p_binds);
                }
                let bindings = bindings
                    .into_iter()
                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                    .collect();
                Ok((preds, bindings))
            } else {
                // Otherwise the scrutinee is unified with `Dict v`, and every
                // value pattern is checked against `v`. Keys are not checked
                // here.
                let elem_tv = Type::var(supply.fresh(Some(Symbol::intern("v"))));
                let dict_ty = Type::app(Type::builtin(BuiltinTypeId::Dict), elem_tv.clone());
                unifier.unify(scrutinee_ty, &dict_ty)?;
                let elem_ty = unifier.apply_type(&elem_tv);

                let mut preds = Vec::new();
                let mut bindings = Vec::new();
                for (_key, pat) in fields {
                    let (p_preds, p_binds) = infer_pattern(unifier, supply, env, pat, &elem_ty)?;
                    preds.extend(p_preds);
                    bindings.extend(p_binds);
                }
                let bindings = bindings
                    .into_iter()
                    .map(|(name, ty)| (name, unifier.apply_type(&ty)))
                    .collect();
                Ok((preds, bindings))
            }
        }
    })();
    res.map_err(|err| err.with_span(&span))
}
2175
2176fn type_head_name(typ: &Type) -> Option<&Symbol> {
2177 let mut cur = typ;
2178 while let TypeKind::App(head, _) = cur.as_ref() {
2179 cur = head;
2180 }
2181 match cur.as_ref() {
2182 TypeKind::Con(tc) => tc.user_name(),
2183 _ => None,
2184 }
2185}
2186
2187fn adt_name_from_patterns(
2188 adts: &BTreeMap<Symbol, AdtDecl>,
2189 patterns: &[Pattern],
2190) -> Option<Symbol> {
2191 let mut candidate: Option<Symbol> = None;
2192 for pat in patterns {
2193 let next = match pat {
2194 Pattern::Named(_, name, _) => {
2195 let name_sym = name.to_dotted_symbol();
2196 ctor_lookup(adts, &name_sym).map(|(adt, _)| adt.name.clone())
2197 }
2198 Pattern::List(..) | Pattern::Cons(..) => Some(Symbol::intern("List")),
2199 _ => None,
2200 };
2201 if let Some(next) = next {
2202 match &candidate {
2203 None => candidate = Some(next),
2204 Some(prev) if *prev == next => {}
2205 Some(_) => return None,
2206 }
2207 }
2208 }
2209 candidate
2210}
2211
2212fn check_match_exhaustive(
2213 adts: &BTreeMap<Symbol, AdtDecl>,
2214 scrutinee_ty: &Type,
2215 patterns: &[Pattern],
2216) -> Result<(), TypeError> {
2217 if patterns
2218 .iter()
2219 .any(|p| matches!(p, Pattern::Wildcard(..) | Pattern::Var(_)))
2220 {
2221 return Ok(());
2222 }
2223 let adt_name = match type_head_name(scrutinee_ty).cloned() {
2224 Some(name) => name,
2225 None => match adt_name_from_patterns(adts, patterns) {
2226 Some(name) => name,
2227 None => return Ok(()),
2228 },
2229 };
2230 let adt = match adts.get(&adt_name) {
2231 Some(adt) => adt,
2232 None => return Ok(()),
2233 };
2234 let ctor_names: BTreeSet<Symbol> = adt.variants.iter().map(|v| v.name.clone()).collect();
2235 if ctor_names.is_empty() {
2236 return Ok(());
2237 }
2238 let mut covered = BTreeSet::new();
2239 for pat in patterns {
2240 match pat {
2241 Pattern::Named(_, name, _) => {
2242 let name_sym = name.to_dotted_symbol();
2243 if ctor_names.contains(&name_sym) {
2244 covered.insert(name_sym);
2245 }
2246 }
2247 Pattern::List(_, elems) if adt_name.as_ref() == "List" && elems.is_empty() => {
2248 covered.insert(Symbol::intern("Empty"));
2249 }
2250 Pattern::Cons(..) if adt_name.as_ref() == "List" => {
2251 covered.insert(Symbol::intern("Cons"));
2252 }
2253 _ => {}
2254 }
2255 }
2256 let mut missing: Vec<Symbol> = ctor_names.difference(&covered).cloned().collect();
2257 if missing.is_empty() {
2258 return Ok(());
2259 }
2260 missing.sort();
2261 Err(TypeError::NonExhaustiveMatch {
2262 typ: scrutinee_ty.to_string(),
2263 missing,
2264 })
2265}