1use crate::builtins::{module_for_import, module_scope};
5use crate::env::{TypeDefKind, TypeEnv, ty_from_canon_env};
6use crate::error::{PositionedError, TypeError};
7use crate::position::Position;
8use crate::types::*;
9use crate::unifier::{UnifyError, Unifier};
10use indexmap::IndexMap;
11use lex_ast as a;
12use std::collections::{BTreeMap, HashMap};
13
/// Record shape extracted for a typed `.parse` call: the record's field names,
/// and `(field name, type tag)` pairs where the tag comes from `ty_to_tag`
/// (built by `extract_record_fields_and_schema`).
type FieldSchema = (Vec<String>, Vec<(String, String)>);
17
/// Everything produced by a successful type check of a program.
pub struct ProgramTypes {
    /// Scheme of every top-level function, inserted in declaration order.
    pub fn_signatures: IndexMap<String, Scheme>,
    /// Builtin types plus the program's user type declarations.
    pub type_env: TypeEnv,
    /// For each `json`/`toml`/`yaml` `.parse` call whose result type is
    /// `Result[<record>, _]`: the record's field names. The key is the address
    /// of the `Call` node in the checked AST.
    /// NOTE(review): address keys are only meaningful while that AST has not
    /// been moved or reallocated.
    pub parse_required_fields: HashMap<usize, Vec<String>>,
    /// Same keys as `parse_required_fields`; values are `(field name, type tag)`
    /// pairs (tags as produced by `ty_to_tag`).
    pub parse_type_schemas: HashMap<usize, Vec<(String, String)>>,
}
37
38pub fn check_program_with_positions(
52 stages: &[a::Stage],
53 positions: &BTreeMap<String, Position>,
54) -> Result<ProgramTypes, Vec<PositionedError>> {
55 check_program_inner(stages, Some(positions))
56 .map_err(|errs| errs.into_iter().map(|(e, fn_name)| {
57 let pos = fn_name.as_deref().and_then(|n| positions.get(n)).cloned();
58 PositionedError::new(e, pos)
59 }).collect())
60}
61
62pub fn check_program(stages: &[a::Stage]) -> Result<ProgramTypes, Vec<TypeError>> {
63 check_program_inner(stages, None)
64 .map_err(|errs| errs.into_iter().map(|(e, _)| e).collect())
65}
66
67fn check_program_inner(
68 stages: &[a::Stage],
69 _positions: Option<&BTreeMap<String, Position>>,
70) -> Result<ProgramTypes, Vec<(TypeError, Option<String>)>> {
71 let mut tcx = Checker::new();
72 let mut errors: Vec<(TypeError, Option<String>)> = Vec::new();
75
76 for stage in stages {
78 if let a::Stage::Import(i) = stage {
79 if let Some(mod_name) = module_for_import(&i.reference) {
80 if let Some(ty) = module_scope(mod_name, &tcx.type_env) {
81 tcx.globals.insert(i.alias.clone(), Scheme {
82 vars: collect_vars(&ty),
86 eff_vars: collect_eff_vars(&ty),
87 ty,
88 });
89 tcx.module_aliases.insert(i.alias.clone(), mod_name.to_string());
90 }
91 }
92 }
93 }
94
95 for stage in stages {
97 if let a::Stage::TypeDecl(td) = stage {
98 if let Err(e) = tcx.type_env.add_user_type(&td.name, td.clone()) {
99 errors.push((TypeError::RecursiveTypeWithoutConstructor {
100 at_node: "n_0".into(),
101 name: e,
102 }, None));
103 }
104 }
105 }
106
107 for stage in stages {
109 if let a::Stage::FnDecl(fd) = stage {
110 let scheme = function_scheme(fd, &tcx.type_env);
111 tcx.globals.insert(fd.name.clone(), scheme);
112 tcx.fn_params.insert(fd.name.clone(), fd.params.clone());
116 }
117 }
118
119 let mut signatures = IndexMap::new();
124 for stage in stages {
125 if let a::Stage::FnDecl(fd) = stage {
126 match tcx.check_fn(fd) {
127 Ok(scheme) => { signatures.insert(fd.name.clone(), scheme); }
128 Err(es) => {
129 errors.extend(es.into_iter().map(|e| (e, Some(fd.name.clone()))));
130 }
131 }
132 }
133 }
134
135 if errors.is_empty() {
136 let mut parse_required_fields = HashMap::new();
142 let mut parse_type_schemas = HashMap::new();
143 for (call_ptr, ret_ty) in &tcx.pending_parse_calls {
144 if let Some((fields, schema)) = extract_record_fields_and_schema(&tcx.u, &tcx.type_env, ret_ty) {
145 parse_required_fields.insert(*call_ptr, fields);
146 parse_type_schemas.insert(*call_ptr, schema);
147 }
148 }
149 Ok(ProgramTypes {
150 fn_signatures: signatures,
151 type_env: tcx.type_env,
152 parse_required_fields,
153 parse_type_schemas,
154 })
155 } else {
156 Err(errors)
157 }
158}
159
160pub fn check_and_rewrite_program(
166 stages: &mut [a::Stage],
167) -> Result<ProgramTypes, Vec<TypeError>> {
168 let pt = check_program(&*stages)?;
173 if !pt.parse_required_fields.is_empty() {
174 rewrite_parse_calls(stages, &pt.parse_required_fields, &pt.parse_type_schemas);
175 }
176 Ok(pt)
177}
178
179fn rewrite_parse_calls(
193 stages: &mut [a::Stage],
194 required: &HashMap<usize, Vec<String>>,
195 schemas: &HashMap<usize, Vec<(String, String)>>,
196) {
197 for stage in stages.iter_mut() {
198 if let a::Stage::FnDecl(fd) = stage {
199 rewrite_in_expr(&mut fd.body, required, schemas);
200 }
201 }
202}
203
204fn rewrite_in_expr(
205 expr: &mut a::CExpr,
206 required: &HashMap<usize, Vec<String>>,
207 schemas: &HashMap<usize, Vec<(String, String)>>,
208) {
209 let ptr = expr as *const a::CExpr as usize;
210 let do_rewrite = required.get(&ptr).cloned();
211 let do_schema = schemas.get(&ptr).cloned();
212 match expr {
218 a::CExpr::Call { callee, args } => {
219 rewrite_in_expr(callee, required, schemas);
220 for a in args.iter_mut() { rewrite_in_expr(a, required, schemas); }
221 }
222 a::CExpr::Let { value, body, .. } => {
223 rewrite_in_expr(value, required, schemas);
224 rewrite_in_expr(body, required, schemas);
225 }
226 a::CExpr::Match { scrutinee, arms } => {
227 rewrite_in_expr(scrutinee, required, schemas);
228 for arm in arms.iter_mut() { rewrite_in_expr(&mut arm.body, required, schemas); }
229 }
230 a::CExpr::Block { statements, result } => {
231 for s in statements.iter_mut() { rewrite_in_expr(s, required, schemas); }
232 rewrite_in_expr(result, required, schemas);
233 }
234 a::CExpr::Constructor { args, .. } => {
235 for a in args.iter_mut() { rewrite_in_expr(a, required, schemas); }
236 }
237 a::CExpr::RecordLit { fields } => {
238 for f in fields.iter_mut() { rewrite_in_expr(&mut f.value, required, schemas); }
239 }
240 a::CExpr::TupleLit { items } | a::CExpr::ListLit { items } => {
241 for it in items.iter_mut() { rewrite_in_expr(it, required, schemas); }
242 }
243 a::CExpr::FieldAccess { value, .. } => rewrite_in_expr(value, required, schemas),
244 a::CExpr::Lambda { body, .. } => rewrite_in_expr(body, required, schemas),
245 a::CExpr::BinOp { lhs, rhs, .. } => {
246 rewrite_in_expr(lhs, required, schemas);
247 rewrite_in_expr(rhs, required, schemas);
248 }
249 a::CExpr::UnaryOp { expr, .. } => rewrite_in_expr(expr, required, schemas),
250 a::CExpr::Return { value } => rewrite_in_expr(value, required, schemas),
251 a::CExpr::Literal { .. } | a::CExpr::Var { .. } => {}
252 }
253 if let Some(fields) = do_rewrite {
254 match expr {
255 a::CExpr::Call { callee, args } => {
256 if let a::CExpr::FieldAccess { field, .. } = callee.as_mut() {
257 debug_assert_eq!(field, "parse",
258 "rewrite_in_expr: only `.parse` calls should be in the table");
259 *field = "parse_strict_typed".to_string();
262 }
263 args.push(a::CExpr::ListLit {
265 items: fields.into_iter()
266 .map(|f| a::CExpr::Literal {
267 value: a::CLit::Str { value: f },
268 })
269 .collect(),
270 });
271 let schema = do_schema.unwrap_or_default();
273 args.push(a::CExpr::ListLit {
274 items: schema.into_iter()
275 .map(|(name, tag)| a::CExpr::TupleLit {
276 items: vec![
277 a::CExpr::Literal { value: a::CLit::Str { value: name } },
278 a::CExpr::Literal { value: a::CLit::Str { value: tag } },
279 ],
280 })
281 .collect(),
282 });
283 }
284 _ => unreachable!("rewrite table key must point to a Call expression"),
285 }
286 }
287}
288
289fn extract_record_fields_and_schema(
295 u: &Unifier,
296 env: &TypeEnv,
297 ty: &Ty,
298) -> Option<FieldSchema> {
299 let resolved = u.resolve(ty);
300 let Ty::Con(ref name, ref args) = resolved else { return None; };
301 if name != "Result" || args.len() != 2 { return None; }
302 let ok_ty = u.resolve(&args[0]);
303 let unfolded = unfold_record_alias_static(env, ok_ty);
304 if let Ty::Record(fields) = unfolded {
305 let names: Vec<String> = fields.keys().cloned().collect();
306 let schema: Vec<(String, String)> = fields.iter()
307 .map(|(k, v)| (k.clone(), ty_to_tag(u, v)))
308 .collect();
309 Some((names, schema))
310 } else {
311 None
312 }
313}
314
315fn ty_to_tag(u: &Unifier, ty: &Ty) -> String {
319 let resolved = u.resolve(ty);
320 match &resolved {
321 Ty::Prim(Prim::Int) => "Int".to_string(),
322 Ty::Prim(Prim::Float) => "Float".to_string(),
323 Ty::Prim(Prim::Bool) => "Bool".to_string(),
324 Ty::Prim(Prim::Str) => "Str".to_string(),
325 Ty::Con(name, args) if name == "Option" && args.len() == 1 => {
326 format!("Option[{}]", ty_to_tag(u, &args[0]))
327 }
328 Ty::List(inner) => {
329 format!("List[{}]", ty_to_tag(u, inner))
330 }
331 Ty::Record(_) => "Record".to_string(),
332 _ => "Any".to_string(),
333 }
334}
335
336fn unfold_record_alias_static(env: &TypeEnv, ty: Ty) -> Ty {
342 if let Ty::Con(ref n, ref args) = ty {
343 if let Some(td) = env.types.get(n) {
344 if let TypeDefKind::Alias(inner) = &td.kind {
345 if td.params.len() != args.len() {
346 return ty;
347 }
348 if td.params.is_empty() {
349 return inner.clone();
350 }
351 let mut subst = IndexMap::new();
352 for (i, a) in args.iter().enumerate() {
353 subst.insert(i as u32, a.clone());
354 }
355 return subst_vars(inner, &subst, &IndexMap::new());
356 }
357 }
358 }
359 ty
360}
361
362fn collect_vars(t: &Ty) -> Vec<TyVarId> {
363 let mut out = Vec::new();
364 fn walk(t: &Ty, out: &mut Vec<TyVarId>) {
365 match t {
366 Ty::Var(v) => { if !out.contains(v) { out.push(*v); } }
367 Ty::Prim(_) | Ty::Unit | Ty::Never => {}
368 Ty::List(inner) => walk(inner, out),
369 Ty::Tuple(items) => for it in items { walk(it, out); },
370 Ty::Record(fs) => for v in fs.values() { walk(v, out); },
371 Ty::Con(_, args) => for a in args { walk(a, out); },
372 Ty::Function { params, ret, .. } => {
373 for p in params { walk(p, out); }
374 walk(ret, out);
375 }
376 }
377 }
378 walk(t, &mut out);
379 out
380}
381
382fn collect_eff_vars(t: &Ty) -> Vec<u32> {
386 let mut out = Vec::new();
387 fn walk(t: &Ty, out: &mut Vec<u32>) {
388 match t {
389 Ty::Var(_) | Ty::Prim(_) | Ty::Unit | Ty::Never => {}
390 Ty::List(inner) => walk(inner, out),
391 Ty::Tuple(items) => for it in items { walk(it, out); },
392 Ty::Record(fs) => for v in fs.values() { walk(v, out); },
393 Ty::Con(_, args) => for a in args { walk(a, out); },
394 Ty::Function { params, effects, ret } => {
395 if let Some(v) = effects.var {
396 if !out.contains(&v) { out.push(v); }
397 }
398 for p in params { walk(p, out); }
399 walk(ret, out);
400 }
401 }
402 }
403 walk(t, &mut out);
404 out
405}
406
407fn function_scheme(fd: &a::FnDecl, env: &TypeEnv) -> Scheme {
408 let params: Vec<Ty> = fd.params.iter().map(|p| ty_from_canon_env(&p.ty, &fd.type_params, env)).collect();
410 let ret = ty_from_canon_env(&fd.return_type, &fd.type_params, env);
411 let effects = EffectSet {
415 concrete: {
416 let mut s = std::collections::BTreeSet::new();
417 for e in &fd.effects {
418 let arg = e.arg.as_ref().map(|a| match a {
419 a::EffectArg::Str { value } => crate::types::EffectArg::Str(value.clone()),
420 a::EffectArg::Int { value } => crate::types::EffectArg::Int(*value),
421 a::EffectArg::Ident { value } => crate::types::EffectArg::Ident(value.clone()),
422 });
423 s.insert(crate::types::EffectKind { name: e.name.clone(), arg });
424 }
425 s
426 },
427 var: None,
428 };
429 let ty = Ty::Function { params, effects, ret: Box::new(ret) };
430 let vars: Vec<TyVarId> = (0..fd.type_params.len() as u32).collect();
431 Scheme { vars, eff_vars: Vec::new(), ty }
435}
436
/// Mutable state threaded through type checking of a whole program.
struct Checker {
    /// Unification engine; one instance is shared by every function checked.
    u: Unifier,
    /// Builtin types plus the user type declarations registered by
    /// `check_program_inner`.
    type_env: TypeEnv,
    /// Top-level names (functions and import aliases) mapped to their schemes.
    globals: IndexMap<String, Scheme>,
    /// Import alias -> module name returned by `module_for_import`; consulted
    /// by `is_module_parse_call`.
    module_aliases: IndexMap<String, String>,
    /// `(address of the Call node, call return type)` for every module
    /// `.parse` call checked; turned into rewrite tables after checking.
    pending_parse_calls: Vec<(usize, Ty)>,
    /// Declared parameters of each top-level function, used by `check_call`
    /// to try to discharge refinement predicates at call sites.
    fn_params: IndexMap<String, Vec<a::Param>>,
}
460
461impl Checker {
462 fn new() -> Self {
463 Self {
464 u: Unifier::new(),
465 type_env: TypeEnv::new_with_builtins(),
466 globals: IndexMap::new(),
467 module_aliases: IndexMap::new(),
468 pending_parse_calls: Vec::new(),
469 fn_params: IndexMap::new(),
470 }
471 }
472
473 fn unfold_record_alias(&self, ty: Ty) -> Ty {
483 if let Ty::Con(ref n, ref args) = ty {
484 if let Some(td) = self.type_env.types.get(n) {
485 if let TypeDefKind::Alias(inner) = &td.kind {
486 if td.params.len() != args.len() {
487 return ty;
488 }
489 if td.params.is_empty() {
490 return inner.clone();
491 }
492 let mut subst = IndexMap::new();
493 for (i, a) in args.iter().enumerate() {
494 subst.insert(i as u32, a.clone());
495 }
496 return subst_vars(inner, &subst, &IndexMap::new());
497 }
498 }
499 }
500 ty
501 }
502
503 fn is_alias_con(&self, ty: &Ty) -> bool {
510 if let Ty::Con(name, args) = ty {
511 if let Some(td) = self.type_env.types.get(name) {
512 if matches!(td.kind, TypeDefKind::Alias(_))
513 && td.params.len() == args.len()
514 {
515 return true;
516 }
517 }
518 }
519 false
520 }
521
522 fn is_module_parse_call(&self, callee: &a::CExpr) -> bool {
527 if let a::CExpr::FieldAccess { value, field } = callee {
528 if field != "parse" { return false; }
529 if let a::CExpr::Var { name } = value.as_ref() {
530 if let Some(module) = self.module_aliases.get(name) {
531 return matches!(module.as_str(), "json" | "toml" | "yaml");
532 }
533 }
534 }
535 false
536 }
537
538 fn unify_with_record_coercion(&mut self, a: &Ty, b: &Ty) -> Result<(), UnifyError> {
550 let a = self.u.resolve(a);
551 let b = self.u.resolve(b);
552 self.unify_coerce_inner(a, b)
553 }
554
555 fn unify_coerce_inner(&mut self, a: Ty, b: Ty) -> Result<(), UnifyError> {
556 let (a, b) = match (&a, &b) {
582 (Ty::Con(n1, _), Ty::Con(n2, _)) if n1 == n2 => (a, b),
583 (Ty::Var(_), _) | (_, Ty::Var(_)) => (a, b),
584 (Ty::Con(_, _), Ty::Con(_, _))
585 if self.is_alias_con(&a) && self.is_alias_con(&b) =>
586 {
587 (a, b)
588 }
589 _ => {
590 let a_u = if let Ty::Con(_, _) = &a {
591 self.unfold_record_alias(a.clone())
592 } else {
593 a
594 };
595 let b_u = if let Ty::Con(_, _) = &b {
596 self.unfold_record_alias(b.clone())
597 } else {
598 b
599 };
600 (a_u, b_u)
601 }
602 };
603
604 match (&a, &b) {
605 (Ty::Record(fa), Ty::Record(fb)) => {
606 if fa.len() != fb.len() {
607 return Err(UnifyError::Mismatch { a: a.clone(), b: b.clone() });
608 }
609 for (k, va) in fa.clone() {
610 match fb.get(&k) {
611 Some(vb) => self.unify_coerce_inner(va, vb.clone())?,
612 None => return Err(UnifyError::Mismatch { a: a.clone(), b: b.clone() }),
613 }
614 }
615 Ok(())
616 }
617 (Ty::List(ta), Ty::List(tb)) => {
618 self.unify_coerce_inner((**ta).clone(), (**tb).clone())
619 }
620 (Ty::Tuple(xs), Ty::Tuple(ys)) if xs.len() == ys.len() => {
621 for (x, y) in xs.clone().into_iter().zip(ys.clone()) {
622 self.unify_coerce_inner(x, y)?;
623 }
624 Ok(())
625 }
626 (Ty::Con(n1, a1), Ty::Con(n2, a2)) if n1 == n2 && a1.len() == a2.len() => {
629 for (x, y) in a1.clone().into_iter().zip(a2.clone()) {
630 self.unify_coerce_inner(x, y)?;
631 }
632 Ok(())
633 }
634 (Ty::Function { params: pa, effects: ea, ret: ra },
639 Ty::Function { params: pb, effects: eb, ret: rb })
640 if pa.len() == pb.len() => {
641 for (x, y) in pa.clone().into_iter().zip(pb.clone()) {
642 self.unify_coerce_inner(x, y)?;
643 }
644 self.u.unify_effects(ea, eb)?;
649 self.unify_coerce_inner((**ra).clone(), (**rb).clone())
650 }
651 _ => self.u.unify(&a, &b),
652 }
653 }
654
655 fn check_fn(&mut self, fd: &a::FnDecl) -> Result<Scheme, Vec<TypeError>> {
656 let scheme = function_scheme(fd, &self.type_env);
658 let (param_tys, declared_effects, ret_ty) = match instantiate(&scheme, &mut self.u) {
659 Ty::Function { params, effects, ret } => (params, effects, *ret),
660 _ => unreachable!(),
661 };
662
663 let mut locals: IndexMap<String, Ty> = IndexMap::new();
664 for (p, t) in fd.params.iter().zip(param_tys.iter()) {
665 locals.insert(p.name.clone(), t.clone());
666 }
667
668 let mut errors: Vec<TypeError> = Vec::new();
672 let mut inferred_effects = EffectSet::empty();
673
674 let body_ok = match self.check_expr(&fd.body, "n_0", &mut locals, &mut inferred_effects) {
676 Ok(body_ty) => {
677 if let Err(e) = self.unify_with_record_coercion(&body_ty, &ret_ty) {
683 errors.push(mismatch_err("n_0", e, &self.u, vec![format!("in function `{}`", fd.name)]));
684 false
685 } else {
686 true
687 }
688 }
689 Err(e) => { errors.push(e); false }
690 };
691
692 if body_ok && !inferred_effects.is_subset(&declared_effects) {
693 for e in inferred_effects.concrete.iter() {
694 if !declared_effects.concrete.iter().any(|d| d.subsumes(e)) {
695 errors.push(TypeError::EffectNotDeclared {
696 at_node: "n_0".into(),
697 effect: e.pretty(),
698 });
699 break;
700 }
701 }
702 }
703
704 if !fd.examples.is_empty() {
709 if !declared_effects.concrete.is_empty() {
710 errors.push(TypeError::ExamplesOnEffectfulFn {
711 at_node: "n_0".into(),
712 fn_name: fd.name.clone(),
713 });
714 } else {
715 for (case_index, ex) in fd.examples.iter().enumerate() {
716 if ex.args.len() != param_tys.len() {
717 errors.push(TypeError::ExampleArityMismatch {
718 at_node: "n_0".into(),
719 fn_name: fd.name.clone(),
720 case_index,
721 expected: param_tys.len(),
722 got: ex.args.len(),
723 });
724 continue;
725 }
726 let mut example_locals: IndexMap<String, Ty> = IndexMap::new();
727 let mut example_effects = EffectSet::empty();
728 let mut args_ok = true;
729 for (i, (arg, expected_ty)) in
730 ex.args.iter().zip(param_tys.iter()).enumerate()
731 {
732 match self.check_expr(arg, "n_0", &mut example_locals, &mut example_effects) {
733 Ok(arg_ty) => {
734 if let Err(e) = self.unify_with_record_coercion(&arg_ty, expected_ty) {
735 errors.push(mismatch_err(
736 "n_0", e, &self.u,
737 vec![format!("in example #{} for `{}`, argument {}", case_index + 1, fd.name, i + 1)],
738 ));
739 args_ok = false;
740 }
741 }
742 Err(e) => { errors.push(e); args_ok = false; }
743 }
744 }
745 if args_ok {
746 match self.check_expr(&ex.expected, "n_0", &mut example_locals, &mut example_effects) {
747 Ok(expected_ty) => {
748 if let Err(e) = self.unify_with_record_coercion(&expected_ty, &ret_ty) {
749 errors.push(mismatch_err(
750 "n_0", e, &self.u,
751 vec![format!("in example #{} for `{}`, expected value", case_index + 1, fd.name)],
752 ));
753 }
754 }
755 Err(e) => errors.push(e),
756 }
757 }
758 if let Some(e) = example_effects.concrete.iter().next() {
763 errors.push(TypeError::EffectNotDeclared {
764 at_node: "n_0".into(),
765 effect: e.pretty(),
766 });
767 }
768 }
769 }
770 }
771
772 if errors.is_empty() { Ok(scheme) } else { Err(errors) }
773 }
774
    /// Infers the type of expression `e`, unifying as it goes.
    ///
    /// `locals` holds the variables in scope and is temporarily extended by
    /// `let`. Effects performed by `e` are accumulated into `effs`, except
    /// those inside a lambda body, which are checked against the lambda's own
    /// declared effects. `node_id` is only used to tag any error returned.
    /// The first error encountered is returned.
    fn check_expr(
        &mut self,
        e: &a::CExpr,
        node_id: &str,
        locals: &mut IndexMap<String, Ty>,
        effs: &mut EffectSet,
    ) -> Result<Ty, TypeError> {
        match e {
            a::CExpr::Literal { value } => Ok(lit_type(value)),
            a::CExpr::Var { name } => {
                // Locals shadow globals; globals are instantiated afresh per use.
                if let Some(t) = locals.get(name) {
                    return Ok(t.clone());
                }
                if let Some(scheme) = self.globals.get(name).cloned() {
                    return Ok(instantiate(&scheme, &mut self.u));
                }
                Err(TypeError::UnknownIdentifier { at_node: node_id.into(), name: name.clone() })
            }
            a::CExpr::Constructor { name, args } => self.check_constructor(name, args, node_id, locals, effs),
            a::CExpr::Call { callee, args } => self.check_call(e, callee, args, node_id, locals, effs),
            a::CExpr::Let { name, ty, value, body } => {
                let v_ty = self.check_expr(value, node_id, locals, effs)?;
                // An annotation only constrains the value via unification; the
                // binding itself keeps the value's inferred type.
                if let Some(declared) = ty {
                    let d = ty_from_canon_env(declared, &[], &self.type_env);
                    if let Err(err) = self.unify_with_record_coercion(&v_ty, &d) {
                        return Err(mismatch_err(node_id, err, &self.u, vec![format!("in let `{}`", name)]));
                    }
                }
                let prev = locals.insert(name.clone(), v_ty);
                let body_ty = self.check_expr(body, node_id, locals, effs)?;
                // Restore any shadowed binding (not done on the error path above).
                match prev {
                    Some(p) => { locals.insert(name.clone(), p); }
                    None => { locals.shift_remove(name); }
                }
                Ok(body_ty)
            }
            a::CExpr::Match { scrutinee, arms } => {
                let scrut_ty = self.check_expr(scrutinee, node_id, locals, effs)?;
                if arms.is_empty() {
                    return Err(TypeError::NonExhaustiveMatch {
                        at_node: node_id.into(), missing: vec!["_".into()]
                    });
                }
                // All arms must agree on one result type.
                let result_ty = self.u.fresh();
                for arm in arms {
                    // Pattern bindings are scoped to their own arm.
                    let mut arm_locals = locals.clone();
                    self.bind_pattern(&arm.pattern, &scrut_ty, &mut arm_locals, node_id)?;
                    let arm_ty = self.check_expr(&arm.body, node_id, &mut arm_locals, effs)?;
                    if let Err(err) = self.unify_with_record_coercion(&arm_ty, &result_ty) {
                        return Err(mismatch_err(node_id, err, &self.u, vec!["in match arm".into()]));
                    }
                }
                Ok(result_ty)
            }
            a::CExpr::Block { statements, result } => {
                // Statement types are discarded; the block's type is `result`'s.
                for s in statements {
                    self.check_expr(s, node_id, locals, effs)?;
                }
                self.check_expr(result, node_id, locals, effs)
            }
            a::CExpr::RecordLit { fields } => {
                let mut tys = IndexMap::new();
                for f in fields {
                    if tys.contains_key(&f.name) {
                        return Err(TypeError::DuplicateField {
                            at_node: node_id.into(), field: f.name.clone()
                        });
                    }
                    let ft = self.check_expr(&f.value, node_id, locals, effs)?;
                    tys.insert(f.name.clone(), ft);
                }
                Ok(Ty::Record(tys))
            }
            a::CExpr::TupleLit { items } => {
                let mut ts = Vec::new();
                for it in items { ts.push(self.check_expr(it, node_id, locals, effs)?); }
                Ok(Ty::Tuple(ts))
            }
            a::CExpr::ListLit { items } => {
                // Every element is unified with one shared element type.
                let elem = self.u.fresh();
                for it in items {
                    let t = self.check_expr(it, node_id, locals, effs)?;
                    if let Err(err) = self.unify_with_record_coercion(&t, &elem) {
                        return Err(mismatch_err(node_id, err, &self.u, vec!["in list literal".into()]));
                    }
                }
                Ok(Ty::List(Box::new(elem)))
            }
            a::CExpr::FieldAccess { value, field } => {
                let vt = self.check_expr(value, node_id, locals, effs)?;
                let resolved = self.u.resolve(&vt);
                // A constructor is replaced by its alias expansion only when that
                // expansion is a record; otherwise the error below names the
                // original constructor type.
                let resolved = if let Ty::Con(_, _) = &resolved {
                    let unfolded = self.unfold_record_alias(resolved.clone());
                    if matches!(unfolded, Ty::Record(_)) {
                        unfolded
                    } else {
                        resolved
                    }
                } else {
                    resolved
                };
                match resolved {
                    Ty::Record(fields) => fields.get(field).cloned()
                        .ok_or_else(|| TypeError::UnknownField {
                            at_node: node_id.into(),
                            record_type: Ty::Record(fields.clone()).pretty(),
                            field: field.clone(),
                        }),
                    other => Err(TypeError::TypeMismatch {
                        at_node: node_id.into(),
                        expected: "record".into(),
                        got: other.pretty(),
                        context: vec![format!("field access `.{}`", field)],
                    }),
                }
            }
            a::CExpr::Lambda { params, return_type, effects: l_effects, body } => {
                // Lambda annotations are lowered with no type parameters in scope.
                let param_tys: Vec<Ty> = params.iter().map(|p| ty_from_canon_env(&p.ty, &[], &self.type_env)).collect();
                let ret_ty = ty_from_canon_env(return_type, &[], &self.type_env);
                let declared = EffectSet {
                    concrete: {
                        let mut s = std::collections::BTreeSet::new();
                        for e in l_effects {
                            let arg = e.arg.as_ref().map(|a| match a {
                                a::EffectArg::Str { value } => crate::types::EffectArg::Str(value.clone()),
                                a::EffectArg::Int { value } => crate::types::EffectArg::Int(*value),
                                a::EffectArg::Ident { value } => crate::types::EffectArg::Ident(value.clone()),
                            });
                            s.insert(crate::types::EffectKind { name: e.name.clone(), arg });
                        }
                        s
                    },
                    var: None,
                };
                // The lambda body sees the enclosing locals plus its parameters.
                let mut inner_locals = locals.clone();
                for (p, t) in params.iter().zip(param_tys.iter()) {
                    inner_locals.insert(p.name.clone(), t.clone());
                }
                // The body's effects are checked against the lambda's declared
                // effects and are NOT added to the enclosing `effs`.
                let mut inner_effs = EffectSet::empty();
                let body_ty = self.check_expr(body, node_id, &mut inner_locals, &mut inner_effs)?;
                if let Err(err) = self.unify_with_record_coercion(&body_ty, &ret_ty) {
                    return Err(mismatch_err(node_id, err, &self.u, vec!["in lambda body".into()]));
                }
                if !inner_effs.is_subset(&declared) {
                    for e in inner_effs.concrete.iter() {
                        if !declared.concrete.iter().any(|d| d.subsumes(e)) {
                            return Err(TypeError::EffectNotDeclared {
                                at_node: node_id.into(),
                                effect: e.pretty(),
                            });
                        }
                    }
                }
                Ok(Ty::function(param_tys, declared, ret_ty))
            }
            a::CExpr::BinOp { op, lhs, rhs } => self.check_binop(op, lhs, rhs, node_id, locals, effs),
            a::CExpr::UnaryOp { op, expr } => {
                let t = self.check_expr(expr, node_id, locals, effs)?;
                match op.as_str() {
                    "-" => {
                        let r = self.u.resolve(&t);
                        match r {
                            Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) => Ok(t),
                            // An operand of still-unknown type defaults to Int.
                            Ty::Var(_) => {
                                self.u.unify(&t, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![]))?;
                                Ok(Ty::int())
                            }
                            other => Err(TypeError::TypeMismatch {
                                at_node: node_id.into(),
                                expected: "Int or Float".into(),
                                got: other.pretty(),
                                context: vec!["unary `-`".into()],
                            }),
                        }
                    }
                    "not" => {
                        self.u.unify(&t, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["unary `not`".into()]))?;
                        Ok(Ty::bool())
                    }
                    other => panic!("unknown unary op: {other}"),
                }
            }
            a::CExpr::Return { value } => {
                // NOTE(review): the returned value is type-checked, but its type is
                // not compared here with the enclosing function's return type.
                self.check_expr(value, node_id, locals, effs)?;
                Ok(Ty::Never)
            }
        }
    }
974
975 fn check_binop(
976 &mut self,
977 op: &str,
978 lhs: &a::CExpr,
979 rhs: &a::CExpr,
980 node_id: &str,
981 locals: &mut IndexMap<String, Ty>,
982 effs: &mut EffectSet,
983 ) -> Result<Ty, TypeError> {
984 let lt = self.check_expr(lhs, node_id, locals, effs)?;
985 let rt = self.check_expr(rhs, node_id, locals, effs)?;
986 match op {
987 "+" => {
988 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
996 let r = self.unfold_record_alias(self.u.resolve(<));
997 match r {
998 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) | Ty::Prim(Prim::Str) => Ok(lt),
999 Ty::Var(_) => {
1000 self.u.unify(<, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1001 Ok(Ty::int())
1002 }
1003 other => Err(TypeError::TypeMismatch {
1004 at_node: node_id.into(),
1005 expected: "Int, Float, or Str".into(),
1006 got: other.pretty(),
1007 context: vec![format!("operator `{op}`")],
1008 }),
1009 }
1010 }
1011 "-" | "*" | "/" | "%" => {
1012 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1013 let r = self.unfold_record_alias(self.u.resolve(<));
1014 match r {
1015 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) => Ok(lt),
1016 Ty::Var(_) => {
1017 self.u.unify(<, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1018 Ok(Ty::int())
1019 }
1020 other => Err(TypeError::TypeMismatch {
1021 at_node: node_id.into(),
1022 expected: "Int or Float".into(),
1023 got: other.pretty(),
1024 context: vec![format!("operator `{op}`")],
1025 }),
1026 }
1027 }
1028 "==" | "!=" => {
1029 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1030 Ok(Ty::bool())
1031 }
1032 "<" | "<=" | ">" | ">=" => {
1033 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1034 let r = self.unfold_record_alias(self.u.resolve(<));
1035 match r {
1036 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) | Ty::Prim(Prim::Str) => Ok(Ty::bool()),
1037 Ty::Var(_) => {
1038 self.u.unify(<, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1039 Ok(Ty::bool())
1040 }
1041 other => Err(TypeError::TypeMismatch {
1042 at_node: node_id.into(),
1043 expected: "Int, Float, or Str".into(),
1044 got: other.pretty(),
1045 context: vec![format!("operator `{op}`")],
1046 }),
1047 }
1048 }
1049 "and" | "or" => {
1050 self.u.unify(<, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1051 self.u.unify(&rt, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
1052 Ok(Ty::bool())
1053 }
1054 other => panic!("unknown binop: {other}"),
1055 }
1056 }
1057
1058 fn check_call(
1059 &mut self,
1060 call_expr: &a::CExpr,
1061 callee: &a::CExpr,
1062 args: &[a::CExpr],
1063 node_id: &str,
1064 locals: &mut IndexMap<String, Ty>,
1065 effs: &mut EffectSet,
1066 ) -> Result<Ty, TypeError> {
1067 let parse_call_ptr = if self.is_module_parse_call(callee) {
1075 Some(call_expr as *const a::CExpr as usize)
1076 } else {
1077 None
1078 };
1079 let callee_ty = self.check_expr(callee, node_id, locals, effs)?;
1080 let resolved = self.u.resolve(&callee_ty);
1081 match resolved {
1082 Ty::Function { params, effects, ret } => {
1083 if params.len() != args.len() {
1084 return Err(TypeError::ArityMismatch {
1085 at_node: node_id.into(),
1086 expected: params.len(),
1087 got: args.len(),
1088 });
1089 }
1090 for (i, (a, p)) in args.iter().zip(params.iter()).enumerate() {
1091 let at = self.check_expr(a, node_id, locals, effs)?;
1092 if let Err(err) = self.unify_with_record_coercion(&at, p) {
1093 return Err(mismatch_err(node_id, err, &self.u, vec![format!("argument {} of call", i + 1)]));
1094 }
1095 }
1096 if let a::CExpr::Var { name: callee_name } = callee {
1103 if let Some(callee_params) = self.fn_params.get(callee_name).cloned() {
1104 for (i, (param, arg)) in callee_params.iter().zip(args.iter()).enumerate() {
1105 if let a::TypeExpr::Refined { binding, predicate, .. } = ¶m.ty {
1106 let outcome = crate::discharge::try_discharge(
1107 predicate, binding, arg);
1108 if let crate::discharge::DischargeOutcome::Refuted { reason } = outcome {
1109 return Err(TypeError::RefinementViolation {
1110 at_node: node_id.into(),
1111 fn_name: callee_name.clone(),
1112 param_index: i,
1113 binding: binding.clone(),
1114 reason,
1115 });
1116 }
1117 }
1118 }
1119 }
1120 }
1121 let resolved_effects = self.u.resolve_effects(&effects);
1126 effs.extend(&resolved_effects);
1127 if let Some(ptr) = parse_call_ptr {
1135 self.pending_parse_calls.push((ptr, (*ret).clone()));
1136 }
1137 Ok(*ret)
1138 }
1139 Ty::Var(_) => {
1140 let mut p_tys = Vec::new();
1142 for a in args { p_tys.push(self.check_expr(a, node_id, locals, effs)?); }
1143 let r = self.u.fresh();
1144 let f = Ty::function(p_tys, EffectSet::empty(), r.clone());
1145 self.u.unify(&callee_ty, &f).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in call".into()]))?;
1146 Ok(r)
1147 }
1148 other => Err(TypeError::TypeMismatch {
1149 at_node: node_id.into(),
1150 expected: "function".into(),
1151 got: other.pretty(),
1152 context: vec!["in call".into()],
1153 }),
1154 }
1155 }
1156
1157 fn check_constructor(
1158 &mut self,
1159 name: &str,
1160 args: &[a::CExpr],
1161 node_id: &str,
1162 locals: &mut IndexMap<String, Ty>,
1163 effs: &mut EffectSet,
1164 ) -> Result<Ty, TypeError> {
1165 let owning = self.type_env.ctor_to_type.get(name).cloned()
1166 .ok_or_else(|| TypeError::UnknownVariant {
1167 at_node: node_id.into(),
1168 constructor: name.to_string(),
1169 })?;
1170 let def = self.type_env.types.get(&owning).cloned()
1171 .expect("ctor_to_type points to a real type");
1172 let variants = match &def.kind {
1173 TypeDefKind::Union(v) => v.clone(),
1174 _ => return Err(TypeError::UnknownVariant {
1175 at_node: node_id.into(),
1176 constructor: name.to_string(),
1177 }),
1178 };
1179 let mut subst = IndexMap::new();
1182 let mut con_args = Vec::with_capacity(def.params.len());
1183 for (i, _p) in def.params.iter().enumerate() {
1184 let fresh = self.u.fresh();
1185 subst.insert(i as u32, fresh.clone());
1186 con_args.push(fresh);
1187 }
1188 let payload = variants.get(name).cloned().flatten();
1189 match (payload, args) {
1190 (None, []) => Ok(Ty::Con(owning, con_args)),
1191 (Some(payload), args) => {
1192 let inst_payload = subst_vars(&payload, &subst, &IndexMap::new());
1193 let arg_count = match &inst_payload {
1194 Ty::Tuple(items) => items.len(),
1195 _ => 1,
1196 };
1197 if arg_count != args.len() {
1198 return Err(TypeError::ArityMismatch {
1199 at_node: node_id.into(),
1200 expected: arg_count,
1201 got: args.len(),
1202 });
1203 }
1204 if args.len() == 1 {
1205 let at = self.check_expr(&args[0], node_id, locals, effs)?;
1206 self.unify_with_record_coercion(&at, &inst_payload).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor `{}`", name)]))?;
1207 } else if let Ty::Tuple(items) = inst_payload {
1208 for (i, (a, t)) in args.iter().zip(items.iter()).enumerate() {
1209 let at = self.check_expr(a, node_id, locals, effs)?;
1210 self.unify_with_record_coercion(&at, t).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor `{}` arg {}", name, i + 1)]))?;
1211 }
1212 }
1213 Ok(Ty::Con(owning, con_args))
1214 }
1215 (None, _) => Err(TypeError::ArityMismatch {
1216 at_node: node_id.into(), expected: 0, got: args.len(),
1217 }),
1218 }
1219 }
1220
1221 fn bind_pattern(
1222 &mut self,
1223 pat: &a::Pattern,
1224 ty: &Ty,
1225 locals: &mut IndexMap<String, Ty>,
1226 node_id: &str,
1227 ) -> Result<(), TypeError> {
1228 match pat {
1229 a::Pattern::PWild => Ok(()),
1230 a::Pattern::PVar { name } => {
1231 locals.insert(name.clone(), ty.clone());
1232 Ok(())
1233 }
1234 a::Pattern::PLiteral { value } => {
1235 let lt = lit_type(value);
1236 self.unify_with_record_coercion(<, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in pattern".into()]))?;
1237 Ok(())
1238 }
1239 a::Pattern::PConstructor { name, args } => {
1240 let owning = self.type_env.ctor_to_type.get(name).cloned()
1242 .ok_or_else(|| TypeError::UnknownVariant {
1243 at_node: node_id.into(), constructor: name.clone(),
1244 })?;
1245 let def = self.type_env.types.get(&owning).cloned().unwrap();
1246 let mut subst = IndexMap::new();
1247 let mut con_args = Vec::new();
1248 for (i, _) in def.params.iter().enumerate() {
1249 let fresh = self.u.fresh();
1250 subst.insert(i as u32, fresh.clone());
1251 con_args.push(fresh);
1252 }
1253 let con_ty = Ty::Con(owning.clone(), con_args);
1254 self.unify_with_record_coercion(&con_ty, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor pattern `{}`", name)]))?;
1255 let payload = match &def.kind {
1256 TypeDefKind::Union(v) => v.get(name).cloned().flatten(),
1257 _ => None,
1258 };
1259 match (payload, args.as_slice()) {
1260 (None, []) => Ok(()),
1261 (Some(payload), args) => {
1262 let inst = subst_vars(&payload, &subst, &IndexMap::new());
1263 if args.len() == 1 {
1264 self.bind_pattern(&args[0], &inst, locals, node_id)?;
1265 } else if let Ty::Tuple(items) = inst {
1266 for (a, t) in args.iter().zip(items.iter()) {
1267 self.bind_pattern(a, t, locals, node_id)?;
1268 }
1269 }
1270 Ok(())
1271 }
1272 (None, _) => Err(TypeError::ArityMismatch {
1273 at_node: node_id.into(), expected: 0, got: args.len(),
1274 }),
1275 }
1276 }
1277 a::Pattern::PRecord { fields } => {
1278 let resolved = self.unfold_record_alias(self.u.resolve(ty));
1283 let rec = match resolved {
1284 Ty::Record(r) => r,
1285 _ => return Err(TypeError::TypeMismatch {
1286 at_node: node_id.into(),
1287 expected: "record".into(),
1288 got: ty.pretty(),
1289 context: vec!["in record pattern".into()],
1290 }),
1291 };
1292 for f in fields {
1293 let ft = rec.get(&f.name).cloned()
1294 .ok_or_else(|| TypeError::UnknownField {
1295 at_node: node_id.into(),
1296 record_type: Ty::Record(rec.clone()).pretty(),
1297 field: f.name.clone(),
1298 })?;
1299 self.bind_pattern(&f.pattern, &ft, locals, node_id)?;
1300 }
1301 Ok(())
1302 }
1303 a::Pattern::PTuple { items } => {
1304 if items.is_empty() {
1306 return self.unify_with_record_coercion(&Ty::Unit, ty)
1307 .map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in unit pattern".into()]));
1308 }
1309 let resolved = self.u.resolve(ty);
1310 let tup = match resolved {
1311 Ty::Tuple(t) => t,
1312 Ty::Var(_) => {
1313 let fresh: Vec<Ty> = items.iter().map(|_| self.u.fresh()).collect();
1314 let tup_ty = Ty::Tuple(fresh.clone());
1315 self.unify_with_record_coercion(&tup_ty, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in tuple pattern".into()]))?;
1316 fresh
1317 }
1318 other => {
1319 return Err(TypeError::TypeMismatch {
1320 at_node: node_id.into(),
1321 expected: "tuple".into(),
1322 got: other.pretty(),
1323 context: vec!["in tuple pattern".into()],
1324 });
1325 }
1326 };
1327 if tup.len() != items.len() {
1328 return Err(TypeError::ArityMismatch {
1329 at_node: node_id.into(), expected: tup.len(), got: items.len(),
1330 });
1331 }
1332 for (p, t) in items.iter().zip(tup.iter()) {
1333 self.bind_pattern(p, t, locals, node_id)?;
1334 }
1335 Ok(())
1336 }
1337 }
1338 }
1339}
1340
1341fn lit_type(l: &a::CLit) -> Ty {
1342 match l {
1343 a::CLit::Int { .. } => Ty::int(),
1344 a::CLit::Float { .. } => Ty::float(),
1345 a::CLit::Str { .. } => Ty::str(),
1346 a::CLit::Bytes { .. } => Ty::bytes(),
1347 a::CLit::Bool { .. } => Ty::bool(),
1348 a::CLit::Unit => Ty::Unit,
1349 }
1350}
1351
1352fn instantiate(s: &Scheme, u: &mut Unifier) -> Ty {
1353 let mut ty_subst = IndexMap::new();
1354 for v in &s.vars { ty_subst.insert(*v, u.fresh()); }
1355 let mut eff_subst = IndexMap::new();
1356 for v in &s.eff_vars { eff_subst.insert(*v, u.fresh_eff_id()); }
1357 subst_vars(&s.ty, &ty_subst, &eff_subst)
1358}
1359
1360fn subst_vars(
1361 t: &Ty,
1362 subst: &IndexMap<TyVarId, Ty>,
1363 eff_subst: &IndexMap<u32, u32>,
1364) -> Ty {
1365 match t {
1366 Ty::Var(v) => subst.get(v).cloned().unwrap_or_else(|| Ty::Var(*v)),
1367 Ty::Prim(_) | Ty::Unit | Ty::Never => t.clone(),
1368 Ty::List(inner) => Ty::List(Box::new(subst_vars(inner, subst, eff_subst))),
1369 Ty::Tuple(items) => Ty::Tuple(items.iter().map(|t| subst_vars(t, subst, eff_subst)).collect()),
1370 Ty::Record(fs) => {
1371 let mut out = IndexMap::new();
1372 for (k, v) in fs { out.insert(k.clone(), subst_vars(v, subst, eff_subst)); }
1373 Ty::Record(out)
1374 }
1375 Ty::Con(n, args) => Ty::Con(n.clone(),
1376 args.iter().map(|t| subst_vars(t, subst, eff_subst)).collect()),
1377 Ty::Function { params, effects, ret } => {
1378 let new_effects = EffectSet {
1381 concrete: effects.concrete.clone(),
1382 var: effects.var.and_then(|v| eff_subst.get(&v).copied()).or(effects.var),
1383 };
1384 Ty::Function {
1385 params: params.iter().map(|t| subst_vars(t, subst, eff_subst)).collect(),
1386 effects: new_effects,
1387 ret: Box::new(subst_vars(ret, subst, eff_subst)),
1388 }
1389 }
1390 }
1391}
1392
1393fn mismatch_err(node_id: &str, e: UnifyError, u: &Unifier, context: Vec<String>) -> TypeError {
1394 match e {
1395 UnifyError::Mismatch { a, b } => TypeError::TypeMismatch {
1396 at_node: node_id.into(),
1397 expected: u.resolve(&b).pretty(),
1398 got: u.resolve(&a).pretty(),
1399 context,
1400 },
1401 UnifyError::Infinite { .. } => TypeError::InfiniteType { at_node: node_id.into() },
1402 UnifyError::EffectMismatch { a, b } => {
1403 let render = |e: &EffectSet| -> String {
1408 let mut parts: Vec<String> = e.concrete.iter()
1409 .map(crate::types::EffectKind::pretty).collect();
1410 if let Some(v) = e.var { parts.push(format!("?e{}", v)); }
1411 if parts.is_empty() { "[]".into() } else { format!("[{}]", parts.join(", ")) }
1412 };
1413 TypeError::EffectRowMismatch {
1414 at_node: node_id.into(),
1415 expected: render(&b),
1416 got: render(&a),
1417 context,
1418 }
1419 }
1420 }
1421}