1use crate::builtins::{module_for_import, module_scope};
5use crate::env::{TypeDefKind, TypeEnv, ty_from_canon_env};
6use crate::error::{PositionedError, TypeError};
7use crate::position::Position;
8use crate::types::*;
9use crate::unifier::{UnifyError, Unifier};
10use indexmap::IndexMap;
11use lex_ast as a;
12use std::collections::{BTreeMap, HashMap};
13
/// Result of inspecting the record type a `.parse` call is expected to produce.
///
/// The first element holds the record's field names, in the record's field-map
/// iteration order. The second holds `(field name, type tag)` pairs, where each
/// tag comes from `ty_to_tag`.
type FieldSchema = (Vec<String>, Vec<(String, String)>);
17
/// Type information produced by a successful program check.
pub struct ProgramTypes {
    /// Scheme of every function declaration that checked cleanly, keyed by
    /// function name, in declaration order.
    pub fn_signatures: IndexMap<String, Scheme>,
    /// The type environment used for checking: builtins plus the program's
    /// user type declarations.
    pub type_env: TypeEnv,
    /// Field names required by each `json`/`toml`/`yaml` `.parse` call whose
    /// result type resolved to `Result[<record or record alias>, _]`.
    ///
    /// Keys are the addresses of the call expressions as seen while checking.
    /// NOTE(review): they only identify nodes while the checked AST has not been
    /// moved or reallocated.
    pub parse_required_fields: HashMap<usize, Vec<String>>,
    /// Same keys as `parse_required_fields`. Each entry lists
    /// `(field name, type tag)` pairs produced by `ty_to_tag`.
    pub parse_type_schemas: HashMap<usize, Vec<(String, String)>>,
}
37
38pub fn check_program_with_positions(
52 stages: &[a::Stage],
53 positions: &BTreeMap<String, Position>,
54) -> Result<ProgramTypes, Vec<PositionedError>> {
55 check_program_inner(stages, Some(positions))
56 .map_err(|errs| errs.into_iter().map(|(e, fn_name)| {
57 let pos = fn_name.as_deref().and_then(|n| positions.get(n)).cloned();
58 PositionedError::new(e, pos)
59 }).collect())
60}
61
62pub fn check_program(stages: &[a::Stage]) -> Result<ProgramTypes, Vec<TypeError>> {
63 check_program_inner(stages, None)
64 .map_err(|errs| errs.into_iter().map(|(e, _)| e).collect())
65}
66
/// Shared driver behind `check_program` and `check_program_with_positions`.
///
/// Makes four passes over `stages`, in this order:
/// 1. registers each import whose reference maps to a known module, as a global
///    scheme plus a module alias;
/// 2. registers user type declarations;
/// 3. registers every function's declared scheme and parameter list, so a body
///    can refer to any top-level function regardless of declaration order;
/// 4. checks every function body.
///
/// Each error is paired with the name of the function being checked when it was
/// raised. Type-declaration errors use `None`.
///
/// On success, the `.parse` call sites recorded during checking are collected
/// into the returned field and schema tables. Only sites whose return type is
/// `Result[<record>, _]` are included.
///
/// `_positions` is currently unused here. Callers map errors to positions
/// themselves.
fn check_program_inner(
    stages: &[a::Stage],
    _positions: Option<&BTreeMap<String, Position>>,
) -> Result<ProgramTypes, Vec<(TypeError, Option<String>)>> {
    let mut tcx = Checker::new();
    let mut errors: Vec<(TypeError, Option<String>)> = Vec::new();

    // Pass 1: imports. This runs before pass 2, so `module_scope` only sees the
    // builtin type environment, not the program's user types.
    for stage in stages {
        if let a::Stage::Import(i) = stage {
            if let Some(mod_name) = module_for_import(&i.reference) {
                if let Some(ty) = module_scope(mod_name, &tcx.type_env) {
                    tcx.globals.insert(i.alias.clone(), Scheme {
                        vars: collect_vars(&ty),
                        eff_vars: collect_eff_vars(&ty),
                        ty,
                    });
                    tcx.module_aliases.insert(i.alias.clone(), mod_name.to_string());
                }
            }
        }
    }

    // Pass 2: user type declarations. Every `add_user_type` failure is reported
    // as `RecursiveTypeWithoutConstructor`, with no owning function.
    for stage in stages {
        if let a::Stage::TypeDecl(td) = stage {
            if let Err(e) = tcx.type_env.add_user_type(&td.name, td.clone()) {
                errors.push((TypeError::RecursiveTypeWithoutConstructor {
                    at_node: "n_0".into(),
                    name: e,
                }, None));
            }
        }
    }

    // Pass 3: declared function signatures. This runs after pass 2 so that
    // signatures can mention user types.
    for stage in stages {
        if let a::Stage::FnDecl(fd) = stage {
            let scheme = function_scheme(fd, &tcx.type_env);
            tcx.globals.insert(fd.name.clone(), scheme);
            tcx.fn_params.insert(fd.name.clone(), fd.params.clone());
        }
    }

    // Pass 4: function bodies. Errors are tagged with the function's name.
    let mut signatures = IndexMap::new();
    for stage in stages {
        if let a::Stage::FnDecl(fd) = stage {
            match tcx.check_fn(fd) {
                Ok(scheme) => { signatures.insert(fd.name.clone(), scheme); }
                Err(es) => {
                    errors.extend(es.into_iter().map(|e| (e, Some(fd.name.clone()))));
                }
            }
        }
    }

    if errors.is_empty() {
        // Resolve the recorded `.parse` sites now that every constraint is in.
        let mut parse_required_fields = HashMap::new();
        let mut parse_type_schemas = HashMap::new();
        for (call_ptr, ret_ty) in &tcx.pending_parse_calls {
            if let Some((fields, schema)) = extract_record_fields_and_schema(&tcx.u, &tcx.type_env, ret_ty) {
                parse_required_fields.insert(*call_ptr, fields);
                parse_type_schemas.insert(*call_ptr, schema);
            }
        }
        Ok(ProgramTypes {
            fn_signatures: signatures,
            type_env: tcx.type_env,
            parse_required_fields,
            parse_type_schemas,
        })
    } else {
        Err(errors)
    }
}
159
160pub fn check_and_rewrite_program(
166 stages: &mut [a::Stage],
167) -> Result<ProgramTypes, Vec<TypeError>> {
168 let pt = check_program(&*stages)?;
173 if !pt.parse_required_fields.is_empty() {
174 rewrite_parse_calls(stages, &pt.parse_required_fields, &pt.parse_type_schemas);
175 }
176 Ok(pt)
177}
178
179fn rewrite_parse_calls(
193 stages: &mut [a::Stage],
194 required: &HashMap<usize, Vec<String>>,
195 schemas: &HashMap<usize, Vec<(String, String)>>,
196) {
197 for stage in stages.iter_mut() {
198 if let a::Stage::FnDecl(fd) = stage {
199 rewrite_in_expr(&mut fd.body, required, schemas);
200 }
201 }
202}
203
/// Recursively rewrites the recorded `.parse` calls inside `expr`.
///
/// A node is rewritten when its address appears in `required`. The rewrite:
/// - renames the callee field from `parse` to `parse_strict_typed`;
/// - appends a list literal of the required field names;
/// - appends a list literal of `(field name, type tag)` tuples taken from
///   `schemas` (an empty list if the node has no schema entry).
///
/// Children are rewritten before the node itself.
///
/// # Panics
/// Panics if an address in `required` belongs to a node that is not a `Call`.
fn rewrite_in_expr(
    expr: &mut a::CExpr,
    required: &HashMap<usize, Vec<String>>,
    schemas: &HashMap<usize, Vec<(String, String)>>,
) {
    // The tables are keyed by node address, recorded during type checking. This
    // only matches if the AST has not moved since then.
    let ptr = expr as *const a::CExpr as usize;
    let do_rewrite = required.get(&ptr).cloned();
    let do_schema = schemas.get(&ptr).cloned();
    match expr {
        a::CExpr::Call { callee, args } => {
            rewrite_in_expr(callee, required, schemas);
            for a in args.iter_mut() { rewrite_in_expr(a, required, schemas); }
        }
        a::CExpr::Let { value, body, .. } => {
            rewrite_in_expr(value, required, schemas);
            rewrite_in_expr(body, required, schemas);
        }
        a::CExpr::Match { scrutinee, arms } => {
            rewrite_in_expr(scrutinee, required, schemas);
            for arm in arms.iter_mut() { rewrite_in_expr(&mut arm.body, required, schemas); }
        }
        a::CExpr::Block { statements, result } => {
            for s in statements.iter_mut() { rewrite_in_expr(s, required, schemas); }
            rewrite_in_expr(result, required, schemas);
        }
        a::CExpr::Constructor { args, .. } => {
            for a in args.iter_mut() { rewrite_in_expr(a, required, schemas); }
        }
        a::CExpr::RecordLit { fields } => {
            for f in fields.iter_mut() { rewrite_in_expr(&mut f.value, required, schemas); }
        }
        a::CExpr::TupleLit { items } | a::CExpr::ListLit { items } => {
            for it in items.iter_mut() { rewrite_in_expr(it, required, schemas); }
        }
        a::CExpr::FieldAccess { value, .. } => rewrite_in_expr(value, required, schemas),
        a::CExpr::Lambda { body, .. } => rewrite_in_expr(body, required, schemas),
        a::CExpr::BinOp { lhs, rhs, .. } => {
            rewrite_in_expr(lhs, required, schemas);
            rewrite_in_expr(rhs, required, schemas);
        }
        a::CExpr::UnaryOp { expr, .. } => rewrite_in_expr(expr, required, schemas),
        a::CExpr::Return { value } => rewrite_in_expr(value, required, schemas),
        a::CExpr::Literal { .. } | a::CExpr::Var { .. } => {}
    }
    if let Some(fields) = do_rewrite {
        match expr {
            a::CExpr::Call { callee, args } => {
                if let a::CExpr::FieldAccess { field, .. } = callee.as_mut() {
                    debug_assert_eq!(field, "parse",
                        "rewrite_in_expr: only `.parse` calls should be in the table");
                    *field = "parse_strict_typed".to_string();
                }
                // Extra argument 1: required field names, as string literals.
                args.push(a::CExpr::ListLit {
                    items: fields.into_iter()
                        .map(|f| a::CExpr::Literal {
                            value: a::CLit::Str { value: f },
                        })
                        .collect(),
                });
                let schema = do_schema.unwrap_or_default();
                // Extra argument 2: `(field name, type tag)` string tuples.
                args.push(a::CExpr::ListLit {
                    items: schema.into_iter()
                        .map(|(name, tag)| a::CExpr::TupleLit {
                            items: vec![
                                a::CExpr::Literal { value: a::CLit::Str { value: name } },
                                a::CExpr::Literal { value: a::CLit::Str { value: tag } },
                            ],
                        })
                        .collect(),
                });
            }
            _ => unreachable!("rewrite table key must point to a Call expression"),
        }
    }
}
288
289fn extract_record_fields_and_schema(
295 u: &Unifier,
296 env: &TypeEnv,
297 ty: &Ty,
298) -> Option<FieldSchema> {
299 let resolved = u.resolve(ty);
300 let Ty::Con(ref name, ref args) = resolved else { return None; };
301 if name != "Result" || args.len() != 2 { return None; }
302 let ok_ty = u.resolve(&args[0]);
303 let unfolded = unfold_record_alias_static(env, ok_ty);
304 if let Ty::Record(fields) = unfolded {
305 let names: Vec<String> = fields.keys().cloned().collect();
306 let schema: Vec<(String, String)> = fields.iter()
307 .map(|(k, v)| (k.clone(), ty_to_tag(u, v)))
308 .collect();
309 Some((names, schema))
310 } else {
311 None
312 }
313}
314
315fn ty_to_tag(u: &Unifier, ty: &Ty) -> String {
319 let resolved = u.resolve(ty);
320 match &resolved {
321 Ty::Prim(Prim::Int) => "Int".to_string(),
322 Ty::Prim(Prim::Float) => "Float".to_string(),
323 Ty::Prim(Prim::Bool) => "Bool".to_string(),
324 Ty::Prim(Prim::Str) => "Str".to_string(),
325 Ty::Con(name, args) if name == "Option" && args.len() == 1 => {
326 format!("Option[{}]", ty_to_tag(u, &args[0]))
327 }
328 Ty::List(inner) => {
329 format!("List[{}]", ty_to_tag(u, inner))
330 }
331 Ty::Record(_) => "Record".to_string(),
332 _ => "Any".to_string(),
333 }
334}
335
336fn unfold_record_alias_static(env: &TypeEnv, ty: Ty) -> Ty {
342 if let Ty::Con(ref n, ref args) = ty {
343 if let Some(td) = env.types.get(n) {
344 if let TypeDefKind::Alias(inner) = &td.kind {
345 if td.params.len() != args.len() {
346 return ty;
347 }
348 if td.params.is_empty() {
349 return inner.clone();
350 }
351 let mut subst = IndexMap::new();
352 for (i, a) in args.iter().enumerate() {
353 subst.insert(i as u32, a.clone());
354 }
355 return subst_vars(inner, &subst, &IndexMap::new());
356 }
357 }
358 }
359 ty
360}
361
362fn collect_vars(t: &Ty) -> Vec<TyVarId> {
363 let mut out = Vec::new();
364 fn walk(t: &Ty, out: &mut Vec<TyVarId>) {
365 match t {
366 Ty::Var(v) => { if !out.contains(v) { out.push(*v); } }
367 Ty::Prim(_) | Ty::Unit | Ty::Never => {}
368 Ty::List(inner) => walk(inner, out),
369 Ty::Tuple(items) => for it in items { walk(it, out); },
370 Ty::Record(fs) => for v in fs.values() { walk(v, out); },
371 Ty::Con(_, args) => for a in args { walk(a, out); },
372 Ty::Function { params, ret, .. } => {
373 for p in params { walk(p, out); }
374 walk(ret, out);
375 }
376 }
377 }
378 walk(t, &mut out);
379 out
380}
381
382fn collect_eff_vars(t: &Ty) -> Vec<u32> {
386 let mut out = Vec::new();
387 fn walk(t: &Ty, out: &mut Vec<u32>) {
388 match t {
389 Ty::Var(_) | Ty::Prim(_) | Ty::Unit | Ty::Never => {}
390 Ty::List(inner) => walk(inner, out),
391 Ty::Tuple(items) => for it in items { walk(it, out); },
392 Ty::Record(fs) => for v in fs.values() { walk(v, out); },
393 Ty::Con(_, args) => for a in args { walk(a, out); },
394 Ty::Function { params, effects, ret } => {
395 if let Some(v) = effects.var {
396 if !out.contains(&v) { out.push(v); }
397 }
398 for p in params { walk(p, out); }
399 walk(ret, out);
400 }
401 }
402 }
403 walk(t, &mut out);
404 out
405}
406
407fn function_scheme(fd: &a::FnDecl, env: &TypeEnv) -> Scheme {
408 let params: Vec<Ty> = fd.params.iter().map(|p| ty_from_canon_env(&p.ty, &fd.type_params, env)).collect();
410 let ret = ty_from_canon_env(&fd.return_type, &fd.type_params, env);
411 let effects = EffectSet {
415 concrete: {
416 let mut s = std::collections::BTreeSet::new();
417 for e in &fd.effects {
418 let arg = e.arg.as_ref().map(|a| match a {
419 a::EffectArg::Str { value } => crate::types::EffectArg::Str(value.clone()),
420 a::EffectArg::Int { value } => crate::types::EffectArg::Int(*value),
421 a::EffectArg::Ident { value } => crate::types::EffectArg::Ident(value.clone()),
422 });
423 s.insert(crate::types::EffectKind { name: e.name.clone(), arg });
424 }
425 s
426 },
427 var: None,
428 };
429 let ty = Ty::Function { params, effects, ret: Box::new(ret) };
430 let vars: Vec<TyVarId> = (0..fd.type_params.len() as u32).collect();
431 Scheme { vars, eff_vars: Vec::new(), ty }
435}
436
/// Mutable state used while checking one program.
struct Checker {
    /// Unification state shared by every function checked with this checker.
    u: Unifier,
    /// Builtin types plus the program's user type declarations.
    type_env: TypeEnv,
    /// Schemes visible by name: imported module aliases and top-level functions.
    globals: IndexMap<String, Scheme>,
    /// Import alias -> module name, for aliases that resolved to a known module.
    module_aliases: IndexMap<String, String>,
    /// `(call expression address, call return type)` for each `json`/`toml`/`yaml`
    /// `.parse` call that was checked. Resolved after all functions are checked.
    pending_parse_calls: Vec<(usize, Ty)>,
    /// Declared parameters of each top-level function. Used to check refinement
    /// predicates at call sites that name the function directly.
    fn_params: IndexMap<String, Vec<a::Param>>,
    /// Errors swallowed by `check_expr_recover`. `check_fn` drains them into the
    /// current function's error list.
    recovered_errors: Vec<TypeError>,
}
466
467impl Checker {
468 fn new() -> Self {
469 Self {
470 u: Unifier::new(),
471 type_env: TypeEnv::new_with_builtins(),
472 globals: IndexMap::new(),
473 module_aliases: IndexMap::new(),
474 pending_parse_calls: Vec::new(),
475 fn_params: IndexMap::new(),
476 recovered_errors: Vec::new(),
477 }
478 }
479
480 fn check_expr_recover(
488 &mut self,
489 e: &a::CExpr,
490 node_id: &str,
491 locals: &mut IndexMap<String, Ty>,
492 effs: &mut EffectSet,
493 ) -> Ty {
494 match self.check_expr(e, node_id, locals, effs) {
495 Ok(ty) => ty,
496 Err(err) => {
497 self.recovered_errors.push(err);
498 self.u.fresh()
499 }
500 }
501 }
502
503 fn unfold_record_alias(&self, ty: Ty) -> Ty {
513 if let Ty::Con(ref n, ref args) = ty {
514 if let Some(td) = self.type_env.types.get(n) {
515 if let TypeDefKind::Alias(inner) = &td.kind {
516 if td.params.len() != args.len() {
517 return ty;
518 }
519 if td.params.is_empty() {
520 return inner.clone();
521 }
522 let mut subst = IndexMap::new();
523 for (i, a) in args.iter().enumerate() {
524 subst.insert(i as u32, a.clone());
525 }
526 return subst_vars(inner, &subst, &IndexMap::new());
527 }
528 }
529 }
530 ty
531 }
532
533 fn is_alias_con(&self, ty: &Ty) -> bool {
540 if let Ty::Con(name, args) = ty {
541 if let Some(td) = self.type_env.types.get(name) {
542 if matches!(td.kind, TypeDefKind::Alias(_))
543 && td.params.len() == args.len()
544 {
545 return true;
546 }
547 }
548 }
549 false
550 }
551
552 fn is_module_parse_call(&self, callee: &a::CExpr) -> bool {
557 if let a::CExpr::FieldAccess { value, field } = callee {
558 if field != "parse" { return false; }
559 if let a::CExpr::Var { name } = value.as_ref() {
560 if let Some(module) = self.module_aliases.get(name) {
561 return matches!(module.as_str(), "json" | "toml" | "yaml");
562 }
563 }
564 }
565 false
566 }
567
568 fn unify_with_record_coercion(&mut self, a: &Ty, b: &Ty) -> Result<(), UnifyError> {
580 let a = self.u.resolve(a);
581 let b = self.u.resolve(b);
582 self.unify_coerce_inner(a, b)
583 }
584
585 fn unify_coerce_inner(&mut self, a: Ty, b: Ty) -> Result<(), UnifyError> {
586 let (a, b) = match (&a, &b) {
612 (Ty::Con(n1, _), Ty::Con(n2, _)) if n1 == n2 => (a, b),
613 (Ty::Var(_), _) | (_, Ty::Var(_)) => (a, b),
614 (Ty::Con(_, _), Ty::Con(_, _))
615 if self.is_alias_con(&a) && self.is_alias_con(&b) =>
616 {
617 (a, b)
618 }
619 _ => {
620 let a_u = if let Ty::Con(_, _) = &a {
621 self.unfold_record_alias(a.clone())
622 } else {
623 a
624 };
625 let b_u = if let Ty::Con(_, _) = &b {
626 self.unfold_record_alias(b.clone())
627 } else {
628 b
629 };
630 (a_u, b_u)
631 }
632 };
633
634 match (&a, &b) {
635 (Ty::Record(fa), Ty::Record(fb)) => {
636 if fa.len() != fb.len() {
637 return Err(UnifyError::Mismatch { a: a.clone(), b: b.clone() });
638 }
639 for (k, va) in fa.clone() {
640 match fb.get(&k) {
641 Some(vb) => self.unify_coerce_inner(va, vb.clone())?,
642 None => return Err(UnifyError::Mismatch { a: a.clone(), b: b.clone() }),
643 }
644 }
645 Ok(())
646 }
647 (Ty::List(ta), Ty::List(tb)) => {
648 self.unify_coerce_inner((**ta).clone(), (**tb).clone())
649 }
650 (Ty::Tuple(xs), Ty::Tuple(ys)) if xs.len() == ys.len() => {
651 for (x, y) in xs.clone().into_iter().zip(ys.clone()) {
652 self.unify_coerce_inner(x, y)?;
653 }
654 Ok(())
655 }
656 (Ty::Con(n1, a1), Ty::Con(n2, a2)) if n1 == n2 && a1.len() == a2.len() => {
659 for (x, y) in a1.clone().into_iter().zip(a2.clone()) {
660 self.unify_coerce_inner(x, y)?;
661 }
662 Ok(())
663 }
664 (Ty::Function { params: pa, effects: ea, ret: ra },
669 Ty::Function { params: pb, effects: eb, ret: rb })
670 if pa.len() == pb.len() => {
671 for (x, y) in pa.clone().into_iter().zip(pb.clone()) {
672 self.unify_coerce_inner(x, y)?;
673 }
674 self.u.unify_effects(ea, eb)?;
679 self.unify_coerce_inner((**ra).clone(), (**rb).clone())
680 }
681 _ => self.u.unify(&a, &b),
682 }
683 }
684
    /// Type-checks one function declaration.
    ///
    /// Checks, in order:
    /// 1. the body against the declared return type;
    /// 2. that every inferred concrete effect is covered by a declared effect.
    ///    This is only done when the body checked cleanly and no error was
    ///    recovered inside it;
    /// 3. the function's `examples`.
    ///
    /// All errors found are collected and returned together. On success this
    /// returns the declared scheme from `function_scheme`, not an inferred one.
    fn check_fn(&mut self, fd: &a::FnDecl) -> Result<Scheme, Vec<TypeError>> {
        let scheme = function_scheme(fd, &self.type_env);
        // Instantiate the declared scheme. Presumably this replaces its bound
        // type variables with fresh ones; the body is checked against the result.
        let (param_tys, declared_effects, ret_ty) = match instantiate(&scheme, &mut self.u) {
            Ty::Function { params, effects, ret } => (params, effects, *ret),
            _ => unreachable!(),
        };

        let mut locals: IndexMap<String, Ty> = IndexMap::new();
        for (p, t) in fd.params.iter().zip(param_tys.iter()) {
            locals.insert(p.name.clone(), t.clone());
        }

        let mut errors: Vec<TypeError> = Vec::new();
        let mut inferred_effects = EffectSet::empty();

        let body_ok = match self.check_expr(&fd.body, "n_0", &mut locals, &mut inferred_effects) {
            Ok(body_ty) => {
                if let Err(e) = self.unify_with_record_coercion(&body_ty, &ret_ty) {
                    errors.push(mismatch_err("n_0", e, &self.u, vec![format!("in function `{}`", fd.name)]));
                    false
                } else {
                    true
                }
            }
            Err(e) => { errors.push(e); false }
        };

        // Errors swallowed by `check_expr_recover` inside the body. The field is
        // drained here and again at the end, so it starts empty for each function.
        let body_had_recovered = !self.recovered_errors.is_empty();
        errors.append(&mut self.recovered_errors);

        // The effect check is skipped when the body had errors: the inferred set
        // would then be incomplete. Only the first undeclared effect is reported.
        if body_ok && !body_had_recovered && !inferred_effects.is_subset(&declared_effects) {
            for e in inferred_effects.concrete.iter() {
                if !declared_effects.concrete.iter().any(|d| d.subsumes(e)) {
                    errors.push(TypeError::EffectNotDeclared {
                        at_node: "n_0".into(),
                        effect: e.pretty(),
                    });
                    break;
                }
            }
        }

        if !fd.examples.is_empty() {
            // Examples are rejected outright on a function that declares any
            // concrete effect.
            if !declared_effects.concrete.is_empty() {
                errors.push(TypeError::ExamplesOnEffectfulFn {
                    at_node: "n_0".into(),
                    fn_name: fd.name.clone(),
                });
            } else {
                for (case_index, ex) in fd.examples.iter().enumerate() {
                    if ex.args.len() != param_tys.len() {
                        errors.push(TypeError::ExampleArityMismatch {
                            at_node: "n_0".into(),
                            fn_name: fd.name.clone(),
                            case_index,
                            expected: param_tys.len(),
                            got: ex.args.len(),
                        });
                        continue;
                    }
                    // Example expressions are checked in an empty local scope:
                    // they cannot see the function's parameters.
                    let mut example_locals: IndexMap<String, Ty> = IndexMap::new();
                    let mut example_effects = EffectSet::empty();
                    let mut args_ok = true;
                    for (i, (arg, expected_ty)) in
                        ex.args.iter().zip(param_tys.iter()).enumerate()
                    {
                        match self.check_expr(arg, "n_0", &mut example_locals, &mut example_effects) {
                            Ok(arg_ty) => {
                                if let Err(e) = self.unify_with_record_coercion(&arg_ty, expected_ty) {
                                    errors.push(mismatch_err(
                                        "n_0", e, &self.u,
                                        vec![format!("in example #{} for `{}`, argument {}", case_index + 1, fd.name, i + 1)],
                                    ));
                                    args_ok = false;
                                }
                            }
                            Err(e) => { errors.push(e); args_ok = false; }
                        }
                    }
                    // The expected value is checked only if all arguments were fine.
                    if args_ok {
                        match self.check_expr(&ex.expected, "n_0", &mut example_locals, &mut example_effects) {
                            Ok(expected_ty) => {
                                if let Err(e) = self.unify_with_record_coercion(&expected_ty, &ret_ty) {
                                    errors.push(mismatch_err(
                                        "n_0", e, &self.u,
                                        vec![format!("in example #{} for `{}`, expected value", case_index + 1, fd.name)],
                                    ));
                                }
                            }
                            Err(e) => errors.push(e),
                        }
                    }
                    // No effect is declared on this path, so any effect used by an
                    // example is undeclared. Only the first one is reported.
                    if let Some(e) = example_effects.concrete.iter().next() {
                        errors.push(TypeError::EffectNotDeclared {
                            at_node: "n_0".into(),
                            effect: e.pretty(),
                        });
                    }
                }
            }
        }

        // Collect errors recovered while checking the examples.
        errors.append(&mut self.recovered_errors);
        if errors.is_empty() { Ok(scheme) } else { Err(errors) }
    }
816
817 fn check_expr(
818 &mut self,
819 e: &a::CExpr,
820 node_id: &str,
821 locals: &mut IndexMap<String, Ty>,
822 effs: &mut EffectSet,
823 ) -> Result<Ty, TypeError> {
824 match e {
825 a::CExpr::Literal { value } => Ok(lit_type(value)),
826 a::CExpr::Var { name } => {
827 if let Some(t) = locals.get(name) {
828 return Ok(t.clone());
829 }
830 if let Some(scheme) = self.globals.get(name).cloned() {
831 return Ok(instantiate(&scheme, &mut self.u));
832 }
833 Err(TypeError::UnknownIdentifier { at_node: node_id.into(), name: name.clone() })
834 }
835 a::CExpr::Constructor { name, args } => self.check_constructor(name, args, node_id, locals, effs),
836 a::CExpr::Call { callee, args } => self.check_call(e, callee, args, node_id, locals, effs),
837 a::CExpr::Let { name, ty, value, body } => {
838 let v_ty = self.check_expr_recover(value, node_id, locals, effs);
842 if let Some(declared) = ty {
843 let d = ty_from_canon_env(declared, &[], &self.type_env);
844 if let Err(err) = self.unify_with_record_coercion(&v_ty, &d) {
845 return Err(mismatch_err(node_id, err, &self.u, vec![format!("in let `{}`", name)]));
846 }
847 }
848 let prev = locals.insert(name.clone(), v_ty);
849 let body_ty = self.check_expr(body, node_id, locals, effs)?;
850 match prev {
851 Some(p) => { locals.insert(name.clone(), p); }
852 None => { locals.shift_remove(name); }
853 }
854 Ok(body_ty)
855 }
856 a::CExpr::Match { scrutinee, arms } => {
857 let scrut_ty = self.check_expr(scrutinee, node_id, locals, effs)?;
858 if arms.is_empty() {
859 return Err(TypeError::NonExhaustiveMatch {
860 at_node: node_id.into(), missing: vec!["_".into()]
861 });
862 }
863 let result_ty = self.u.fresh();
864 for arm in arms {
865 let mut arm_locals = locals.clone();
866 self.bind_pattern(&arm.pattern, &scrut_ty, &mut arm_locals, node_id)?;
867 let arm_ty = self.check_expr(&arm.body, node_id, &mut arm_locals, effs)?;
868 if let Err(err) = self.unify_with_record_coercion(&arm_ty, &result_ty) {
869 return Err(mismatch_err(node_id, err, &self.u, vec!["in match arm".into()]));
870 }
871 }
872 Ok(result_ty)
873 }
874 a::CExpr::Block { statements, result } => {
875 for s in statements {
879 let _ = self.check_expr_recover(s, node_id, locals, effs);
880 }
881 self.check_expr(result, node_id, locals, effs)
882 }
883 a::CExpr::RecordLit { fields } => {
884 let mut tys = IndexMap::new();
885 for f in fields {
886 if tys.contains_key(&f.name) {
887 return Err(TypeError::DuplicateField {
888 at_node: node_id.into(), field: f.name.clone()
889 });
890 }
891 let ft = self.check_expr(&f.value, node_id, locals, effs)?;
892 tys.insert(f.name.clone(), ft);
893 }
894 Ok(Ty::Record(tys))
895 }
896 a::CExpr::TupleLit { items } => {
897 let mut ts = Vec::new();
898 for it in items { ts.push(self.check_expr(it, node_id, locals, effs)?); }
899 Ok(Ty::Tuple(ts))
900 }
901 a::CExpr::ListLit { items } => {
902 let elem = self.u.fresh();
903 for it in items {
904 let t = self.check_expr(it, node_id, locals, effs)?;
905 if let Err(err) = self.unify_with_record_coercion(&t, &elem) {
906 return Err(mismatch_err(node_id, err, &self.u, vec!["in list literal".into()]));
907 }
908 }
909 Ok(Ty::List(Box::new(elem)))
910 }
911 a::CExpr::FieldAccess { value, field } => {
912 let vt = self.check_expr(value, node_id, locals, effs)?;
913 let resolved = self.u.resolve(&vt);
914 let resolved = if let Ty::Con(_, _) = &resolved {
922 let unfolded = self.unfold_record_alias(resolved.clone());
923 if matches!(unfolded, Ty::Record(_)) {
924 unfolded
925 } else {
926 resolved
927 }
928 } else {
929 resolved
930 };
931 match resolved {
932 Ty::Record(fields) => fields.get(field).cloned()
933 .ok_or_else(|| TypeError::UnknownField {
934 at_node: node_id.into(),
935 record_type: Ty::Record(fields.clone()).pretty(),
936 field: field.clone(),
937 }),
938 other => Err(TypeError::TypeMismatch {
939 at_node: node_id.into(),
940 expected: "record".into(),
941 got: other.pretty(),
942 context: vec![format!("field access `.{}`", field)],
943 }),
944 }
945 }
946 a::CExpr::Lambda { params, return_type, effects: l_effects, body } => {
947 let param_tys: Vec<Ty> = params.iter().map(|p| ty_from_canon_env(&p.ty, &[], &self.type_env)).collect();
948 let ret_ty = ty_from_canon_env(return_type, &[], &self.type_env);
949 let declared = EffectSet {
950 concrete: {
951 let mut s = std::collections::BTreeSet::new();
952 for e in l_effects {
953 let arg = e.arg.as_ref().map(|a| match a {
954 a::EffectArg::Str { value } => crate::types::EffectArg::Str(value.clone()),
955 a::EffectArg::Int { value } => crate::types::EffectArg::Int(*value),
956 a::EffectArg::Ident { value } => crate::types::EffectArg::Ident(value.clone()),
957 });
958 s.insert(crate::types::EffectKind { name: e.name.clone(), arg });
959 }
960 s
961 },
962 var: None,
963 };
964 let mut inner_locals = locals.clone();
965 for (p, t) in params.iter().zip(param_tys.iter()) {
966 inner_locals.insert(p.name.clone(), t.clone());
967 }
968 let mut inner_effs = EffectSet::empty();
969 let body_ty = self.check_expr(body, node_id, &mut inner_locals, &mut inner_effs)?;
970 if let Err(err) = self.unify_with_record_coercion(&body_ty, &ret_ty) {
971 return Err(mismatch_err(node_id, err, &self.u, vec!["in lambda body".into()]));
972 }
973 if !inner_effs.is_subset(&declared) {
974 for e in inner_effs.concrete.iter() {
975 if !declared.concrete.iter().any(|d| d.subsumes(e)) {
976 return Err(TypeError::EffectNotDeclared {
977 at_node: node_id.into(),
978 effect: e.pretty(),
979 });
980 }
981 }
982 }
983 Ok(Ty::function(param_tys, declared, ret_ty))
984 }
985 a::CExpr::BinOp { op, lhs, rhs } => self.check_binop(op, lhs, rhs, node_id, locals, effs),
986 a::CExpr::UnaryOp { op, expr } => {
987 let t = self.check_expr(expr, node_id, locals, effs)?;
988 match op.as_str() {
989 "-" => {
990 let r = self.u.resolve(&t);
992 match r {
993 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) => Ok(t),
994 Ty::Var(_) => {
995 self.u.unify(&t, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![]))?;
997 Ok(Ty::int())
998 }
999 other => Err(TypeError::TypeMismatch {
1000 at_node: node_id.into(),
1001 expected: "Int or Float".into(),
1002 got: other.pretty(),
1003 context: vec!["unary `-`".into()],
1004 }),
1005 }
1006 }
1007 "not" => {
1008 self.u.unify(&t, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["unary `not`".into()]))?;
1009 Ok(Ty::bool())
1010 }
1011 other => panic!("unknown unary op: {other}"),
1012 }
1013 }
1014 a::CExpr::Return { value } => {
1015 self.check_expr(value, node_id, locals, effs)?;
1018 Ok(Ty::Never)
1019 }
1020 }
1021 }
1022
1023 fn check_binop(
1024 &mut self,
1025 op: &str,
1026 lhs: &a::CExpr,
1027 rhs: &a::CExpr,
1028 node_id: &str,
1029 locals: &mut IndexMap<String, Ty>,
1030 effs: &mut EffectSet,
1031 ) -> Result<Ty, TypeError> {
1032 let lt = self.check_expr(lhs, node_id, locals, effs)?;
1033 let rt = self.check_expr(rhs, node_id, locals, effs)?;
1034 match op {
1035 "+" => {
1036 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1044 let r = self.unfold_record_alias(self.u.resolve(<));
1045 match r {
1046 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) | Ty::Prim(Prim::Str) => Ok(lt),
1047 Ty::Var(_) => {
1048 self.u.unify(<, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1049 Ok(Ty::int())
1050 }
1051 other => Err(TypeError::TypeMismatch {
1052 at_node: node_id.into(),
1053 expected: "Int, Float, or Str".into(),
1054 got: other.pretty(),
1055 context: vec![format!("operator `{op}`")],
1056 }),
1057 }
1058 }
1059 "-" | "*" | "/" | "%" => {
1060 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1061 let r = self.unfold_record_alias(self.u.resolve(<));
1062 match r {
1063 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) => Ok(lt),
1064 Ty::Var(_) => {
1065 self.u.unify(<, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1066 Ok(Ty::int())
1067 }
1068 other => Err(TypeError::TypeMismatch {
1069 at_node: node_id.into(),
1070 expected: "Int or Float".into(),
1071 got: other.pretty(),
1072 context: vec![format!("operator `{op}`")],
1073 }),
1074 }
1075 }
1076 "==" | "!=" => {
1077 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1078 Ok(Ty::bool())
1079 }
1080 "<" | "<=" | ">" | ">=" => {
1081 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1082 let r = self.unfold_record_alias(self.u.resolve(<));
1083 match r {
1084 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) | Ty::Prim(Prim::Str) => Ok(Ty::bool()),
1085 Ty::Var(_) => {
1086 self.u.unify(<, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1087 Ok(Ty::bool())
1088 }
1089 other => Err(TypeError::TypeMismatch {
1090 at_node: node_id.into(),
1091 expected: "Int, Float, or Str".into(),
1092 got: other.pretty(),
1093 context: vec![format!("operator `{op}`")],
1094 }),
1095 }
1096 }
1097 "and" | "or" => {
1098 self.u.unify(<, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1099 self.u.unify(&rt, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1100 Ok(Ty::bool())
1101 }
1102 other => panic!("unknown binop: {other}"),
1103 }
1104 }
1105
    /// Types a call `callee(args...)`.
    ///
    /// If the callee resolves to a function type:
    /// - arity and argument types are checked;
    /// - refinement predicates on the parameters of a directly named top-level
    ///   function are discharged against the argument expressions;
    /// - the function's effects are added to `effs`.
    ///
    /// If the callee's type is still a variable, it is unified with a fresh
    /// function type that has no effects. `call_expr` must be the enclosing `Call`
    /// node: its address is recorded for module `.parse` calls.
    ///
    /// # Errors
    /// Reports the first arity, argument or refinement error, or a mismatch if the
    /// callee's type is not a function.
    fn check_call(
        &mut self,
        call_expr: &a::CExpr,
        callee: &a::CExpr,
        args: &[a::CExpr],
        node_id: &str,
        locals: &mut IndexMap<String, Ty>,
        effs: &mut EffectSet,
    ) -> Result<Ty, TypeError> {
        // Capture the node's address up front. It is recorded below, together with
        // the call's return type, for the post-check `.parse` rewrite.
        let parse_call_ptr = if self.is_module_parse_call(callee) {
            Some(call_expr as *const a::CExpr as usize)
        } else {
            None
        };
        let callee_ty = self.check_expr(callee, node_id, locals, effs)?;
        let resolved = self.u.resolve(&callee_ty);
        match resolved {
            Ty::Function { params, effects, ret } => {
                if params.len() != args.len() {
                    return Err(TypeError::ArityMismatch {
                        at_node: node_id.into(),
                        expected: params.len(),
                        got: args.len(),
                    });
                }
                for (i, (a, p)) in args.iter().zip(params.iter()).enumerate() {
                    let at = self.check_expr(a, node_id, locals, effs)?;
                    if let Err(err) = self.unify_with_record_coercion(&at, p) {
                        return Err(mismatch_err(node_id, err, &self.u, vec![format!("argument {} of call", i + 1)]));
                    }
                }
                // Refinement discharge, only for callees written as a bare name.
                // NOTE(review): the lookup is by name only. A local binding that
                // shadows a top-level function would still be checked against that
                // function's refinements — confirm this is intended.
                if let a::CExpr::Var { name: callee_name } = callee {
                    if let Some(callee_params) = self.fn_params.get(callee_name).cloned() {
                        for (i, (param, arg)) in callee_params.iter().zip(args.iter()).enumerate() {
                            if let a::TypeExpr::Refined { binding, predicate, .. } = &param.ty {
                                let outcome = crate::discharge::try_discharge(
                                    predicate, binding, arg);
                                // Only a definite refutation is an error. Other
                                // outcomes do not stop the call from checking.
                                if let crate::discharge::DischargeOutcome::Refuted { reason } = outcome {
                                    return Err(TypeError::RefinementViolation {
                                        at_node: node_id.into(),
                                        fn_name: callee_name.clone(),
                                        param_index: i,
                                        binding: binding.clone(),
                                        reason,
                                    });
                                }
                            }
                        }
                    }
                }
                let resolved_effects = self.u.resolve_effects(&effects);
                effs.extend(&resolved_effects);
                if let Some(ptr) = parse_call_ptr {
                    self.pending_parse_calls.push((ptr, (*ret).clone()));
                }
                Ok(*ret)
            }
            Ty::Var(_) => {
                // Callee of unknown type: assume a function of the argument types
                // with no effects, and a fresh return type.
                let mut p_tys = Vec::new();
                for a in args { p_tys.push(self.check_expr(a, node_id, locals, effs)?); }
                let r = self.u.fresh();
                let f = Ty::function(p_tys, EffectSet::empty(), r.clone());
                self.u.unify(&callee_ty, &f).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in call".into()]))?;
                Ok(r)
            }
            other => Err(TypeError::TypeMismatch {
                at_node: node_id.into(),
                expected: "function".into(),
                got: other.pretty(),
                context: vec!["in call".into()],
            }),
        }
    }
1204
    /// Types a union-variant constructor application `name(args...)`.
    ///
    /// The owning union's type parameters are instantiated with fresh variables.
    /// The arguments are then checked against the variant's payload:
    /// - no payload requires no arguments;
    /// - a tuple payload takes one argument per element;
    /// - any other payload takes exactly one argument.
    ///
    /// # Errors
    /// Returns `UnknownVariant` if `name` is not a union constructor,
    /// `ArityMismatch` on a wrong argument count, or a mismatch error from
    /// argument unification.
    ///
    /// # Panics
    /// Panics if `ctor_to_type` names a type missing from `types`.
    fn check_constructor(
        &mut self,
        name: &str,
        args: &[a::CExpr],
        node_id: &str,
        locals: &mut IndexMap<String, Ty>,
        effs: &mut EffectSet,
    ) -> Result<Ty, TypeError> {
        let owning = self.type_env.ctor_to_type.get(name).cloned()
            .ok_or_else(|| TypeError::UnknownVariant {
                at_node: node_id.into(),
                constructor: name.to_string(),
            })?;
        let def = self.type_env.types.get(&owning).cloned()
            .expect("ctor_to_type points to a real type");
        let variants = match &def.kind {
            TypeDefKind::Union(v) => v.clone(),
            _ => return Err(TypeError::UnknownVariant {
                at_node: node_id.into(),
                constructor: name.to_string(),
            }),
        };
        // One fresh variable per type parameter. Parameter `i` maps to variable
        // `i` in the payload type.
        let mut subst = IndexMap::new();
        let mut con_args = Vec::with_capacity(def.params.len());
        for (i, _p) in def.params.iter().enumerate() {
            let fresh = self.u.fresh();
            subst.insert(i as u32, fresh.clone());
            con_args.push(fresh);
        }
        let payload = variants.get(name).cloned().flatten();
        match (payload, args) {
            (None, []) => Ok(Ty::Con(owning, con_args)),
            (Some(payload), args) => {
                let inst_payload = subst_vars(&payload, &subst, &IndexMap::new());
                let arg_count = match &inst_payload {
                    Ty::Tuple(items) => items.len(),
                    _ => 1,
                };
                if arg_count != args.len() {
                    return Err(TypeError::ArityMismatch {
                        at_node: node_id.into(),
                        expected: arg_count,
                        got: args.len(),
                    });
                }
                // NOTE(review): a 1-element tuple payload with one argument is
                // unified against the whole tuple, not its element. Confirm that
                // tuple payloads always have at least 2 items.
                if args.len() == 1 {
                    let at = self.check_expr(&args[0], node_id, locals, effs)?;
                    self.unify_with_record_coercion(&at, &inst_payload).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor `{}`", name)]))?;
                } else if let Ty::Tuple(items) = inst_payload {
                    for (i, (a, t)) in args.iter().zip(items.iter()).enumerate() {
                        let at = self.check_expr(a, node_id, locals, effs)?;
                        self.unify_with_record_coercion(&at, t).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor `{}` arg {}", name, i + 1)]))?;
                    }
                }
                Ok(Ty::Con(owning, con_args))
            }
            (None, _) => Err(TypeError::ArityMismatch {
                at_node: node_id.into(), expected: 0, got: args.len(),
            }),
        }
    }
1268
1269 fn bind_pattern(
1270 &mut self,
1271 pat: &a::Pattern,
1272 ty: &Ty,
1273 locals: &mut IndexMap<String, Ty>,
1274 node_id: &str,
1275 ) -> Result<(), TypeError> {
1276 match pat {
1277 a::Pattern::PWild => Ok(()),
1278 a::Pattern::PVar { name } => {
1279 locals.insert(name.clone(), ty.clone());
1280 Ok(())
1281 }
1282 a::Pattern::PLiteral { value } => {
1283 let lt = lit_type(value);
1284 self.unify_with_record_coercion(<, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in pattern".into()]))?;
1285 Ok(())
1286 }
1287 a::Pattern::PConstructor { name, args } => {
1288 let owning = self.type_env.ctor_to_type.get(name).cloned()
1290 .ok_or_else(|| TypeError::UnknownVariant {
1291 at_node: node_id.into(), constructor: name.clone(),
1292 })?;
1293 let def = self.type_env.types.get(&owning).cloned().unwrap();
1294 let mut subst = IndexMap::new();
1295 let mut con_args = Vec::new();
1296 for (i, _) in def.params.iter().enumerate() {
1297 let fresh = self.u.fresh();
1298 subst.insert(i as u32, fresh.clone());
1299 con_args.push(fresh);
1300 }
1301 let con_ty = Ty::Con(owning.clone(), con_args);
1302 self.unify_with_record_coercion(&con_ty, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor pattern `{}`", name)]))?;
1303 let payload = match &def.kind {
1304 TypeDefKind::Union(v) => v.get(name).cloned().flatten(),
1305 _ => None,
1306 };
1307 match (payload, args.as_slice()) {
1308 (None, []) => Ok(()),
1309 (Some(payload), args) => {
1310 let inst = subst_vars(&payload, &subst, &IndexMap::new());
1311 if args.len() == 1 {
1312 self.bind_pattern(&args[0], &inst, locals, node_id)?;
1313 } else if let Ty::Tuple(items) = inst {
1314 for (a, t) in args.iter().zip(items.iter()) {
1315 self.bind_pattern(a, t, locals, node_id)?;
1316 }
1317 }
1318 Ok(())
1319 }
1320 (None, _) => Err(TypeError::ArityMismatch {
1321 at_node: node_id.into(), expected: 0, got: args.len(),
1322 }),
1323 }
1324 }
1325 a::Pattern::PRecord { fields } => {
1326 let resolved = self.unfold_record_alias(self.u.resolve(ty));
1331 let rec = match resolved {
1332 Ty::Record(r) => r,
1333 _ => return Err(TypeError::TypeMismatch {
1334 at_node: node_id.into(),
1335 expected: "record".into(),
1336 got: ty.pretty(),
1337 context: vec!["in record pattern".into()],
1338 }),
1339 };
1340 for f in fields {
1341 let ft = rec.get(&f.name).cloned()
1342 .ok_or_else(|| TypeError::UnknownField {
1343 at_node: node_id.into(),
1344 record_type: Ty::Record(rec.clone()).pretty(),
1345 field: f.name.clone(),
1346 })?;
1347 self.bind_pattern(&f.pattern, &ft, locals, node_id)?;
1348 }
1349 Ok(())
1350 }
1351 a::Pattern::PTuple { items } => {
1352 if items.is_empty() {
1354 return self.unify_with_record_coercion(&Ty::Unit, ty)
1355 .map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in unit pattern".into()]));
1356 }
1357 let resolved = self.u.resolve(ty);
1358 let tup = match resolved {
1359 Ty::Tuple(t) => t,
1360 Ty::Var(_) => {
1361 let fresh: Vec<Ty> = items.iter().map(|_| self.u.fresh()).collect();
1362 let tup_ty = Ty::Tuple(fresh.clone());
1363 self.unify_with_record_coercion(&tup_ty, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in tuple pattern".into()]))?;
1364 fresh
1365 }
1366 other => {
1367 return Err(TypeError::TypeMismatch {
1368 at_node: node_id.into(),
1369 expected: "tuple".into(),
1370 got: other.pretty(),
1371 context: vec!["in tuple pattern".into()],
1372 });
1373 }
1374 };
1375 if tup.len() != items.len() {
1376 return Err(TypeError::ArityMismatch {
1377 at_node: node_id.into(), expected: tup.len(), got: items.len(),
1378 });
1379 }
1380 for (p, t) in items.iter().zip(tup.iter()) {
1381 self.bind_pattern(p, t, locals, node_id)?;
1382 }
1383 Ok(())
1384 }
1385 }
1386 }
1387}
1388
1389fn lit_type(l: &a::CLit) -> Ty {
1390 match l {
1391 a::CLit::Int { .. } => Ty::int(),
1392 a::CLit::Float { .. } => Ty::float(),
1393 a::CLit::Str { .. } => Ty::str(),
1394 a::CLit::Bytes { .. } => Ty::bytes(),
1395 a::CLit::Bool { .. } => Ty::bool(),
1396 a::CLit::Unit => Ty::Unit,
1397 }
1398}
1399
1400fn instantiate(s: &Scheme, u: &mut Unifier) -> Ty {
1401 let mut ty_subst = IndexMap::new();
1402 for v in &s.vars { ty_subst.insert(*v, u.fresh()); }
1403 let mut eff_subst = IndexMap::new();
1404 for v in &s.eff_vars { eff_subst.insert(*v, u.fresh_eff_id()); }
1405 subst_vars(&s.ty, &ty_subst, &eff_subst)
1406}
1407
1408fn subst_vars(
1409 t: &Ty,
1410 subst: &IndexMap<TyVarId, Ty>,
1411 eff_subst: &IndexMap<u32, u32>,
1412) -> Ty {
1413 match t {
1414 Ty::Var(v) => subst.get(v).cloned().unwrap_or_else(|| Ty::Var(*v)),
1415 Ty::Prim(_) | Ty::Unit | Ty::Never => t.clone(),
1416 Ty::List(inner) => Ty::List(Box::new(subst_vars(inner, subst, eff_subst))),
1417 Ty::Tuple(items) => Ty::Tuple(items.iter().map(|t| subst_vars(t, subst, eff_subst)).collect()),
1418 Ty::Record(fs) => {
1419 let mut out = IndexMap::new();
1420 for (k, v) in fs { out.insert(k.clone(), subst_vars(v, subst, eff_subst)); }
1421 Ty::Record(out)
1422 }
1423 Ty::Con(n, args) => Ty::Con(n.clone(),
1424 args.iter().map(|t| subst_vars(t, subst, eff_subst)).collect()),
1425 Ty::Function { params, effects, ret } => {
1426 let new_effects = EffectSet {
1429 concrete: effects.concrete.clone(),
1430 var: effects.var.and_then(|v| eff_subst.get(&v).copied()).or(effects.var),
1431 };
1432 Ty::Function {
1433 params: params.iter().map(|t| subst_vars(t, subst, eff_subst)).collect(),
1434 effects: new_effects,
1435 ret: Box::new(subst_vars(ret, subst, eff_subst)),
1436 }
1437 }
1438 }
1439}
1440
1441fn mismatch_err(node_id: &str, e: UnifyError, u: &Unifier, context: Vec<String>) -> TypeError {
1442 match e {
1443 UnifyError::Mismatch { a, b } => TypeError::TypeMismatch {
1444 at_node: node_id.into(),
1445 expected: u.resolve(&b).pretty(),
1446 got: u.resolve(&a).pretty(),
1447 context,
1448 },
1449 UnifyError::Infinite { .. } => TypeError::InfiniteType { at_node: node_id.into() },
1450 UnifyError::EffectMismatch { a, b } => {
1451 let render = |e: &EffectSet| -> String {
1456 let mut parts: Vec<String> = e.concrete.iter()
1457 .map(crate::types::EffectKind::pretty).collect();
1458 if let Some(v) = e.var { parts.push(format!("?e{}", v)); }
1459 if parts.is_empty() { "[]".into() } else { format!("[{}]", parts.join(", ")) }
1460 };
1461 TypeError::EffectRowMismatch {
1462 at_node: node_id.into(),
1463 expected: render(&b),
1464 got: render(&a),
1465 context,
1466 }
1467 }
1468 }
1469}