1use crate::builtins::{module_for_import, module_scope};
5use crate::env::{TypeDefKind, TypeEnv, ty_from_canon};
6use crate::error::TypeError;
7use crate::types::*;
8use crate::unifier::{UnifyError, Unifier};
9use indexmap::IndexMap;
10use lex_ast as a;
11use std::collections::HashMap;
12
13pub struct ProgramTypes {
15 pub fn_signatures: IndexMap<String, Scheme>,
16 pub type_env: TypeEnv,
17 pub parse_required_fields: HashMap<usize, Vec<String>>,
26}
27
28pub fn check_program(stages: &[a::Stage]) -> Result<ProgramTypes, Vec<TypeError>> {
29 let mut tcx = Checker::new();
30 let mut errors = Vec::new();
31
32 for stage in stages {
34 if let a::Stage::Import(i) = stage {
35 if let Some(mod_name) = module_for_import(&i.reference) {
36 if let Some(ty) = module_scope(mod_name, &tcx.type_env) {
37 tcx.globals.insert(i.alias.clone(), Scheme {
38 vars: collect_vars(&ty),
42 eff_vars: collect_eff_vars(&ty),
43 ty,
44 });
45 tcx.module_aliases.insert(i.alias.clone(), mod_name.to_string());
46 }
47 }
48 }
49 }
50
51 for stage in stages {
53 if let a::Stage::TypeDecl(td) = stage {
54 if let Err(e) = tcx.type_env.add_user_type(&td.name, td.clone()) {
55 errors.push(TypeError::RecursiveTypeWithoutConstructor {
56 at_node: "n_0".into(),
57 name: e,
58 });
59 }
60 }
61 }
62
63 for stage in stages {
65 if let a::Stage::FnDecl(fd) = stage {
66 let scheme = function_scheme(fd);
67 tcx.globals.insert(fd.name.clone(), scheme);
68 }
69 }
70
71 let mut signatures = IndexMap::new();
73 for stage in stages {
74 if let a::Stage::FnDecl(fd) = stage {
75 match tcx.check_fn(fd) {
76 Ok(scheme) => { signatures.insert(fd.name.clone(), scheme); }
77 Err(es) => errors.extend(es),
78 }
79 }
80 }
81
82 if errors.is_empty() {
83 let mut parse_required_fields = HashMap::new();
89 for (call_ptr, ret_ty) in &tcx.pending_parse_calls {
90 if let Some(fields) = extract_record_fields_from_result(&tcx.u, &tcx.type_env, ret_ty) {
91 parse_required_fields.insert(*call_ptr, fields);
92 }
93 }
94 Ok(ProgramTypes {
95 fn_signatures: signatures,
96 type_env: tcx.type_env,
97 parse_required_fields,
98 })
99 } else {
100 Err(errors)
101 }
102}
103
104pub fn check_and_rewrite_program(
110 stages: &mut [a::Stage],
111) -> Result<ProgramTypes, Vec<TypeError>> {
112 let pt = check_program(&*stages)?;
117 if !pt.parse_required_fields.is_empty() {
118 rewrite_parse_calls(stages, &pt.parse_required_fields);
119 }
120 Ok(pt)
121}
122
123fn rewrite_parse_calls(stages: &mut [a::Stage], required: &HashMap<usize, Vec<String>>) {
137 for stage in stages.iter_mut() {
138 if let a::Stage::FnDecl(fd) = stage {
139 rewrite_in_expr(&mut fd.body, required);
140 }
141 }
142}
143
144fn rewrite_in_expr(expr: &mut a::CExpr, required: &HashMap<usize, Vec<String>>) {
145 let ptr = expr as *const a::CExpr as usize;
146 let do_rewrite = required.get(&ptr).cloned();
147 match expr {
153 a::CExpr::Call { callee, args } => {
154 rewrite_in_expr(callee, required);
155 for a in args.iter_mut() { rewrite_in_expr(a, required); }
156 }
157 a::CExpr::Let { value, body, .. } => {
158 rewrite_in_expr(value, required);
159 rewrite_in_expr(body, required);
160 }
161 a::CExpr::Match { scrutinee, arms } => {
162 rewrite_in_expr(scrutinee, required);
163 for arm in arms.iter_mut() { rewrite_in_expr(&mut arm.body, required); }
164 }
165 a::CExpr::Block { statements, result } => {
166 for s in statements.iter_mut() { rewrite_in_expr(s, required); }
167 rewrite_in_expr(result, required);
168 }
169 a::CExpr::Constructor { args, .. } => {
170 for a in args.iter_mut() { rewrite_in_expr(a, required); }
171 }
172 a::CExpr::RecordLit { fields } => {
173 for f in fields.iter_mut() { rewrite_in_expr(&mut f.value, required); }
174 }
175 a::CExpr::TupleLit { items } | a::CExpr::ListLit { items } => {
176 for it in items.iter_mut() { rewrite_in_expr(it, required); }
177 }
178 a::CExpr::FieldAccess { value, .. } => rewrite_in_expr(value, required),
179 a::CExpr::Lambda { body, .. } => rewrite_in_expr(body, required),
180 a::CExpr::BinOp { lhs, rhs, .. } => {
181 rewrite_in_expr(lhs, required);
182 rewrite_in_expr(rhs, required);
183 }
184 a::CExpr::UnaryOp { expr, .. } => rewrite_in_expr(expr, required),
185 a::CExpr::Return { value } => rewrite_in_expr(value, required),
186 a::CExpr::Literal { .. } | a::CExpr::Var { .. } => {}
187 }
188 if let Some(fields) = do_rewrite {
189 match expr {
190 a::CExpr::Call { callee, args } => {
191 if let a::CExpr::FieldAccess { field, .. } = callee.as_mut() {
192 debug_assert_eq!(field, "parse",
193 "rewrite_in_expr: only `.parse` calls should be in the table");
194 *field = "parse_strict".to_string();
195 }
196 args.push(a::CExpr::ListLit {
197 items: fields.into_iter()
198 .map(|f| a::CExpr::Literal {
199 value: a::CLit::Str { value: f },
200 })
201 .collect(),
202 });
203 }
204 _ => unreachable!("rewrite table key must point to a Call expression"),
205 }
206 }
207}
208
209fn extract_record_fields_from_result(
214 u: &Unifier,
215 env: &TypeEnv,
216 ty: &Ty,
217) -> Option<Vec<String>> {
218 let resolved = u.resolve(ty);
219 let Ty::Con(ref name, ref args) = resolved else { return None; };
220 if name != "Result" || args.len() != 2 { return None; }
221 let ok_ty = u.resolve(&args[0]);
222 let unfolded = unfold_record_alias_static(env, ok_ty);
223 if let Ty::Record(fields) = unfolded {
224 Some(fields.keys().cloned().collect())
225 } else {
226 None
227 }
228}
229
230fn unfold_record_alias_static(env: &TypeEnv, ty: Ty) -> Ty {
235 if let Ty::Con(ref n, _) = ty {
236 if let Some(td) = env.types.get(n) {
237 if let TypeDefKind::Alias(inner @ Ty::Record(_)) = &td.kind {
238 return inner.clone();
239 }
240 }
241 }
242 ty
243}
244
245fn collect_vars(t: &Ty) -> Vec<TyVarId> {
246 let mut out = Vec::new();
247 fn walk(t: &Ty, out: &mut Vec<TyVarId>) {
248 match t {
249 Ty::Var(v) => { if !out.contains(v) { out.push(*v); } }
250 Ty::Prim(_) | Ty::Unit | Ty::Never => {}
251 Ty::List(inner) => walk(inner, out),
252 Ty::Tuple(items) => for it in items { walk(it, out); },
253 Ty::Record(fs) => for v in fs.values() { walk(v, out); },
254 Ty::Con(_, args) => for a in args { walk(a, out); },
255 Ty::Function { params, ret, .. } => {
256 for p in params { walk(p, out); }
257 walk(ret, out);
258 }
259 }
260 }
261 walk(t, &mut out);
262 out
263}
264
265fn collect_eff_vars(t: &Ty) -> Vec<u32> {
269 let mut out = Vec::new();
270 fn walk(t: &Ty, out: &mut Vec<u32>) {
271 match t {
272 Ty::Var(_) | Ty::Prim(_) | Ty::Unit | Ty::Never => {}
273 Ty::List(inner) => walk(inner, out),
274 Ty::Tuple(items) => for it in items { walk(it, out); },
275 Ty::Record(fs) => for v in fs.values() { walk(v, out); },
276 Ty::Con(_, args) => for a in args { walk(a, out); },
277 Ty::Function { params, effects, ret } => {
278 if let Some(v) = effects.var {
279 if !out.contains(&v) { out.push(v); }
280 }
281 for p in params { walk(p, out); }
282 walk(ret, out);
283 }
284 }
285 }
286 walk(t, &mut out);
287 out
288}
289
290fn function_scheme(fd: &a::FnDecl) -> Scheme {
291 let params: Vec<Ty> = fd.params.iter().map(|p| ty_from_canon(&p.ty, &fd.type_params)).collect();
293 let ret = ty_from_canon(&fd.return_type, &fd.type_params);
294 let effects = EffectSet {
295 concrete: {
296 let mut s = std::collections::BTreeSet::new();
297 for e in &fd.effects { s.insert(e.name.clone()); }
298 s
299 },
300 var: None,
301 };
302 let ty = Ty::Function { params, effects, ret: Box::new(ret) };
303 let vars: Vec<TyVarId> = (0..fd.type_params.len() as u32).collect();
304 Scheme { vars, eff_vars: Vec::new(), ty }
308}
309
310struct Checker {
311 u: Unifier,
312 type_env: TypeEnv,
313 globals: IndexMap<String, Scheme>,
314 module_aliases: IndexMap<String, String>,
318 pending_parse_calls: Vec<(usize, Ty)>,
326}
327
328impl Checker {
329 fn new() -> Self {
330 Self {
331 u: Unifier::new(),
332 type_env: TypeEnv::new_with_builtins(),
333 globals: IndexMap::new(),
334 module_aliases: IndexMap::new(),
335 pending_parse_calls: Vec::new(),
336 }
337 }
338
339 fn unfold_record_alias(&self, ty: Ty) -> Ty {
343 if let Ty::Con(ref n, _) = ty {
344 if let Some(td) = self.type_env.types.get(n) {
345 if let TypeDefKind::Alias(inner @ Ty::Record(_)) = &td.kind {
346 return inner.clone();
347 }
348 }
349 }
350 ty
351 }
352
353 fn is_module_parse_call(&self, callee: &a::CExpr) -> bool {
358 if let a::CExpr::FieldAccess { value, field } = callee {
359 if field != "parse" { return false; }
360 if let a::CExpr::Var { name } = value.as_ref() {
361 if let Some(module) = self.module_aliases.get(name) {
362 return matches!(module.as_str(), "json" | "toml" | "yaml");
363 }
364 }
365 }
366 false
367 }
368
369 fn unify_with_record_coercion(&mut self, a: &Ty, b: &Ty) -> Result<(), UnifyError> {
381 let a = self.u.resolve(a);
382 let b = self.u.resolve(b);
383 self.unify_coerce_inner(a, b)
384 }
385
386 fn unify_coerce_inner(&mut self, a: Ty, b: Ty) -> Result<(), UnifyError> {
387 let (a, b) = match (&a, &b) {
389 (Ty::Record(_), Ty::Con(_, _)) => (a, self.unfold_record_alias(b.clone())),
390 (Ty::Con(_, _), Ty::Record(_)) => (self.unfold_record_alias(a.clone()), b),
391 _ => (a, b),
392 };
393
394 match (&a, &b) {
395 (Ty::Record(fa), Ty::Record(fb)) => {
396 if fa.len() != fb.len() {
397 return Err(UnifyError::Mismatch { a: a.clone(), b: b.clone() });
398 }
399 for (k, va) in fa.clone() {
400 match fb.get(&k) {
401 Some(vb) => self.unify_coerce_inner(va, vb.clone())?,
402 None => return Err(UnifyError::Mismatch { a: a.clone(), b: b.clone() }),
403 }
404 }
405 Ok(())
406 }
407 (Ty::List(ta), Ty::List(tb)) => {
408 self.unify_coerce_inner((**ta).clone(), (**tb).clone())
409 }
410 (Ty::Tuple(xs), Ty::Tuple(ys)) if xs.len() == ys.len() => {
411 for (x, y) in xs.clone().into_iter().zip(ys.clone()) {
412 self.unify_coerce_inner(x, y)?;
413 }
414 Ok(())
415 }
416 _ => self.u.unify(&a, &b),
417 }
418 }
419
420 fn check_fn(&mut self, fd: &a::FnDecl) -> Result<Scheme, Vec<TypeError>> {
421 let scheme = function_scheme(fd);
423 let (param_tys, declared_effects, ret_ty) = match instantiate(&scheme, &mut self.u) {
424 Ty::Function { params, effects, ret } => (params, effects, *ret),
425 _ => unreachable!(),
426 };
427
428 let mut locals: IndexMap<String, Ty> = IndexMap::new();
429 for (p, t) in fd.params.iter().zip(param_tys.iter()) {
430 locals.insert(p.name.clone(), t.clone());
431 }
432
433 let mut inferred_effects = EffectSet::empty();
434 let body_ty = self.check_expr(&fd.body, "n_0", &mut locals, &mut inferred_effects)
435 .map_err(|e| vec![e])?;
436
437 if let Err(e) = self.unify_with_record_coercion(&body_ty, &ret_ty) {
443 return Err(vec![mismatch_err("n_0", e, &self.u, vec![format!("in function `{}`", fd.name)])]);
444 }
445
446 if !inferred_effects.is_subset(&declared_effects) {
447 for e in inferred_effects.concrete.iter() {
449 if !declared_effects.concrete.contains(e) {
450 return Err(vec![TypeError::EffectNotDeclared {
451 at_node: "n_0".into(),
452 effect: e.clone(),
453 }]);
454 }
455 }
456 }
457
458 Ok(scheme)
459 }
460
461 fn check_expr(
462 &mut self,
463 e: &a::CExpr,
464 node_id: &str,
465 locals: &mut IndexMap<String, Ty>,
466 effs: &mut EffectSet,
467 ) -> Result<Ty, TypeError> {
468 match e {
469 a::CExpr::Literal { value } => Ok(lit_type(value)),
470 a::CExpr::Var { name } => {
471 if let Some(t) = locals.get(name) {
472 return Ok(t.clone());
473 }
474 if let Some(scheme) = self.globals.get(name).cloned() {
475 return Ok(instantiate(&scheme, &mut self.u));
476 }
477 Err(TypeError::UnknownIdentifier { at_node: node_id.into(), name: name.clone() })
478 }
479 a::CExpr::Constructor { name, args } => self.check_constructor(name, args, node_id, locals, effs),
480 a::CExpr::Call { callee, args } => self.check_call(e, callee, args, node_id, locals, effs),
481 a::CExpr::Let { name, ty, value, body } => {
482 let v_ty = self.check_expr(value, node_id, locals, effs)?;
483 if let Some(declared) = ty {
484 let d = ty_from_canon(declared, &[]);
485 if let Err(err) = self.unify_with_record_coercion(&v_ty, &d) {
486 return Err(mismatch_err(node_id, err, &self.u, vec![format!("in let `{}`", name)]));
487 }
488 }
489 let prev = locals.insert(name.clone(), v_ty);
490 let body_ty = self.check_expr(body, node_id, locals, effs)?;
491 match prev {
492 Some(p) => { locals.insert(name.clone(), p); }
493 None => { locals.shift_remove(name); }
494 }
495 Ok(body_ty)
496 }
497 a::CExpr::Match { scrutinee, arms } => {
498 let scrut_ty = self.check_expr(scrutinee, node_id, locals, effs)?;
499 if arms.is_empty() {
500 return Err(TypeError::NonExhaustiveMatch {
501 at_node: node_id.into(), missing: vec!["_".into()]
502 });
503 }
504 let result_ty = self.u.fresh();
505 for arm in arms {
506 let mut arm_locals = locals.clone();
507 self.bind_pattern(&arm.pattern, &scrut_ty, &mut arm_locals, node_id)?;
508 let arm_ty = self.check_expr(&arm.body, node_id, &mut arm_locals, effs)?;
509 if let Err(err) = self.unify_with_record_coercion(&arm_ty, &result_ty) {
510 return Err(mismatch_err(node_id, err, &self.u, vec!["in match arm".into()]));
511 }
512 }
513 Ok(result_ty)
514 }
515 a::CExpr::Block { statements, result } => {
516 for s in statements {
517 self.check_expr(s, node_id, locals, effs)?;
518 }
519 self.check_expr(result, node_id, locals, effs)
520 }
521 a::CExpr::RecordLit { fields } => {
522 let mut tys = IndexMap::new();
523 for f in fields {
524 if tys.contains_key(&f.name) {
525 return Err(TypeError::DuplicateField {
526 at_node: node_id.into(), field: f.name.clone()
527 });
528 }
529 let ft = self.check_expr(&f.value, node_id, locals, effs)?;
530 tys.insert(f.name.clone(), ft);
531 }
532 Ok(Ty::Record(tys))
533 }
534 a::CExpr::TupleLit { items } => {
535 let mut ts = Vec::new();
536 for it in items { ts.push(self.check_expr(it, node_id, locals, effs)?); }
537 Ok(Ty::Tuple(ts))
538 }
539 a::CExpr::ListLit { items } => {
540 let elem = self.u.fresh();
541 for it in items {
542 let t = self.check_expr(it, node_id, locals, effs)?;
543 if let Err(err) = self.unify_with_record_coercion(&t, &elem) {
544 return Err(mismatch_err(node_id, err, &self.u, vec!["in list literal".into()]));
545 }
546 }
547 Ok(Ty::List(Box::new(elem)))
548 }
549 a::CExpr::FieldAccess { value, field } => {
550 let vt = self.check_expr(value, node_id, locals, effs)?;
551 let resolved = self.u.resolve(&vt);
552 let resolved = match resolved {
554 Ty::Con(ref n, _) => match self.type_env.types.get(n) {
555 Some(td) => match &td.kind {
556 TypeDefKind::Alias(inner @ Ty::Record(_)) => inner.clone(),
557 _ => resolved,
558 },
559 None => resolved,
560 },
561 other => other,
562 };
563 match resolved {
564 Ty::Record(fields) => fields.get(field).cloned()
565 .ok_or_else(|| TypeError::UnknownField {
566 at_node: node_id.into(),
567 record_type: Ty::Record(fields.clone()).pretty(),
568 field: field.clone(),
569 }),
570 other => Err(TypeError::TypeMismatch {
571 at_node: node_id.into(),
572 expected: "record".into(),
573 got: other.pretty(),
574 context: vec![format!("field access `.{}`", field)],
575 }),
576 }
577 }
578 a::CExpr::Lambda { params, return_type, effects: l_effects, body } => {
579 let param_tys: Vec<Ty> = params.iter().map(|p| ty_from_canon(&p.ty, &[])).collect();
580 let ret_ty = ty_from_canon(return_type, &[]);
581 let declared = EffectSet {
582 concrete: {
583 let mut s = std::collections::BTreeSet::new();
584 for e in l_effects { s.insert(e.name.clone()); }
585 s
586 },
587 var: None,
588 };
589 let mut inner_locals = locals.clone();
590 for (p, t) in params.iter().zip(param_tys.iter()) {
591 inner_locals.insert(p.name.clone(), t.clone());
592 }
593 let mut inner_effs = EffectSet::empty();
594 let body_ty = self.check_expr(body, node_id, &mut inner_locals, &mut inner_effs)?;
595 if let Err(err) = self.unify_with_record_coercion(&body_ty, &ret_ty) {
596 return Err(mismatch_err(node_id, err, &self.u, vec!["in lambda body".into()]));
597 }
598 if !inner_effs.is_subset(&declared) {
599 for e in inner_effs.concrete.iter() {
600 if !declared.concrete.contains(e) {
601 return Err(TypeError::EffectNotDeclared {
602 at_node: node_id.into(),
603 effect: e.clone(),
604 });
605 }
606 }
607 }
608 Ok(Ty::function(param_tys, declared, ret_ty))
609 }
610 a::CExpr::BinOp { op, lhs, rhs } => self.check_binop(op, lhs, rhs, node_id, locals, effs),
611 a::CExpr::UnaryOp { op, expr } => {
612 let t = self.check_expr(expr, node_id, locals, effs)?;
613 match op.as_str() {
614 "-" => {
615 let r = self.u.resolve(&t);
617 match r {
618 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) => Ok(t),
619 Ty::Var(_) => {
620 self.u.unify(&t, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![]))?;
622 Ok(Ty::int())
623 }
624 other => Err(TypeError::TypeMismatch {
625 at_node: node_id.into(),
626 expected: "Int or Float".into(),
627 got: other.pretty(),
628 context: vec!["unary `-`".into()],
629 }),
630 }
631 }
632 "not" => {
633 self.u.unify(&t, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["unary `not`".into()]))?;
634 Ok(Ty::bool())
635 }
636 other => panic!("unknown unary op: {other}"),
637 }
638 }
639 a::CExpr::Return { value } => {
640 self.check_expr(value, node_id, locals, effs)?;
643 Ok(Ty::Never)
644 }
645 }
646 }
647
648 fn check_binop(
649 &mut self,
650 op: &str,
651 lhs: &a::CExpr,
652 rhs: &a::CExpr,
653 node_id: &str,
654 locals: &mut IndexMap<String, Ty>,
655 effs: &mut EffectSet,
656 ) -> Result<Ty, TypeError> {
657 let lt = self.check_expr(lhs, node_id, locals, effs)?;
658 let rt = self.check_expr(rhs, node_id, locals, effs)?;
659 match op {
660 "+" | "-" | "*" | "/" | "%" => {
661 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
662 let r = self.u.resolve(<);
663 match r {
664 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) => Ok(lt),
665 Ty::Var(_) => {
666 self.u.unify(<, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
667 Ok(Ty::int())
668 }
669 other => Err(TypeError::TypeMismatch {
670 at_node: node_id.into(),
671 expected: "Int or Float".into(),
672 got: other.pretty(),
673 context: vec![format!("operator `{op}`")],
674 }),
675 }
676 }
677 "==" | "!=" => {
678 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
679 Ok(Ty::bool())
680 }
681 "<" | "<=" | ">" | ">=" => {
682 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
683 let r = self.u.resolve(<);
684 match r {
685 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) | Ty::Prim(Prim::Str) => Ok(Ty::bool()),
686 Ty::Var(_) => {
687 self.u.unify(<, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
688 Ok(Ty::bool())
689 }
690 other => Err(TypeError::TypeMismatch {
691 at_node: node_id.into(),
692 expected: "Int, Float, or Str".into(),
693 got: other.pretty(),
694 context: vec![format!("operator `{op}`")],
695 }),
696 }
697 }
698 "and" | "or" => {
699 self.u.unify(<, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
700 self.u.unify(&rt, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
701 Ok(Ty::bool())
702 }
703 other => panic!("unknown binop: {other}"),
704 }
705 }
706
707 fn check_call(
708 &mut self,
709 call_expr: &a::CExpr,
710 callee: &a::CExpr,
711 args: &[a::CExpr],
712 node_id: &str,
713 locals: &mut IndexMap<String, Ty>,
714 effs: &mut EffectSet,
715 ) -> Result<Ty, TypeError> {
716 let parse_call_ptr = if self.is_module_parse_call(callee) {
724 Some(call_expr as *const a::CExpr as usize)
725 } else {
726 None
727 };
728 let callee_ty = self.check_expr(callee, node_id, locals, effs)?;
729 let resolved = self.u.resolve(&callee_ty);
730 match resolved {
731 Ty::Function { params, effects, ret } => {
732 if params.len() != args.len() {
733 return Err(TypeError::ArityMismatch {
734 at_node: node_id.into(),
735 expected: params.len(),
736 got: args.len(),
737 });
738 }
739 for (i, (a, p)) in args.iter().zip(params.iter()).enumerate() {
740 let at = self.check_expr(a, node_id, locals, effs)?;
741 if let Err(err) = self.unify_with_record_coercion(&at, p) {
742 return Err(mismatch_err(node_id, err, &self.u, vec![format!("argument {} of call", i + 1)]));
743 }
744 }
745 let resolved_effects = self.u.resolve_effects(&effects);
750 effs.extend(&resolved_effects);
751 if let Some(ptr) = parse_call_ptr {
759 self.pending_parse_calls.push((ptr, (*ret).clone()));
760 }
761 Ok(*ret)
762 }
763 Ty::Var(_) => {
764 let mut p_tys = Vec::new();
766 for a in args { p_tys.push(self.check_expr(a, node_id, locals, effs)?); }
767 let r = self.u.fresh();
768 let f = Ty::function(p_tys, EffectSet::empty(), r.clone());
769 self.u.unify(&callee_ty, &f).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in call".into()]))?;
770 Ok(r)
771 }
772 other => Err(TypeError::TypeMismatch {
773 at_node: node_id.into(),
774 expected: "function".into(),
775 got: other.pretty(),
776 context: vec!["in call".into()],
777 }),
778 }
779 }
780
781 fn check_constructor(
782 &mut self,
783 name: &str,
784 args: &[a::CExpr],
785 node_id: &str,
786 locals: &mut IndexMap<String, Ty>,
787 effs: &mut EffectSet,
788 ) -> Result<Ty, TypeError> {
789 let owning = self.type_env.ctor_to_type.get(name).cloned()
790 .ok_or_else(|| TypeError::UnknownVariant {
791 at_node: node_id.into(),
792 constructor: name.to_string(),
793 })?;
794 let def = self.type_env.types.get(&owning).cloned()
795 .expect("ctor_to_type points to a real type");
796 let variants = match &def.kind {
797 TypeDefKind::Union(v) => v.clone(),
798 _ => return Err(TypeError::UnknownVariant {
799 at_node: node_id.into(),
800 constructor: name.to_string(),
801 }),
802 };
803 let mut subst = IndexMap::new();
806 let mut con_args = Vec::with_capacity(def.params.len());
807 for (i, _p) in def.params.iter().enumerate() {
808 let fresh = self.u.fresh();
809 subst.insert(i as u32, fresh.clone());
810 con_args.push(fresh);
811 }
812 let payload = variants.get(name).cloned().flatten();
813 match (payload, args) {
814 (None, []) => Ok(Ty::Con(owning, con_args)),
815 (Some(payload), args) => {
816 let inst_payload = subst_vars(&payload, &subst, &IndexMap::new());
817 let arg_count = match &inst_payload {
818 Ty::Tuple(items) => items.len(),
819 _ => 1,
820 };
821 if arg_count != args.len() {
822 return Err(TypeError::ArityMismatch {
823 at_node: node_id.into(),
824 expected: arg_count,
825 got: args.len(),
826 });
827 }
828 if args.len() == 1 {
829 let at = self.check_expr(&args[0], node_id, locals, effs)?;
830 self.unify_with_record_coercion(&at, &inst_payload).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor `{}`", name)]))?;
831 } else if let Ty::Tuple(items) = inst_payload {
832 for (i, (a, t)) in args.iter().zip(items.iter()).enumerate() {
833 let at = self.check_expr(a, node_id, locals, effs)?;
834 self.unify_with_record_coercion(&at, t).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor `{}` arg {}", name, i + 1)]))?;
835 }
836 }
837 Ok(Ty::Con(owning, con_args))
838 }
839 (None, _) => Err(TypeError::ArityMismatch {
840 at_node: node_id.into(), expected: 0, got: args.len(),
841 }),
842 }
843 }
844
845 fn bind_pattern(
846 &mut self,
847 pat: &a::Pattern,
848 ty: &Ty,
849 locals: &mut IndexMap<String, Ty>,
850 node_id: &str,
851 ) -> Result<(), TypeError> {
852 match pat {
853 a::Pattern::PWild => Ok(()),
854 a::Pattern::PVar { name } => {
855 locals.insert(name.clone(), ty.clone());
856 Ok(())
857 }
858 a::Pattern::PLiteral { value } => {
859 let lt = lit_type(value);
860 self.unify_with_record_coercion(<, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in pattern".into()]))?;
861 Ok(())
862 }
863 a::Pattern::PConstructor { name, args } => {
864 let owning = self.type_env.ctor_to_type.get(name).cloned()
866 .ok_or_else(|| TypeError::UnknownVariant {
867 at_node: node_id.into(), constructor: name.clone(),
868 })?;
869 let def = self.type_env.types.get(&owning).cloned().unwrap();
870 let mut subst = IndexMap::new();
871 let mut con_args = Vec::new();
872 for (i, _) in def.params.iter().enumerate() {
873 let fresh = self.u.fresh();
874 subst.insert(i as u32, fresh.clone());
875 con_args.push(fresh);
876 }
877 let con_ty = Ty::Con(owning.clone(), con_args);
878 self.unify_with_record_coercion(&con_ty, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor pattern `{}`", name)]))?;
879 let payload = match &def.kind {
880 TypeDefKind::Union(v) => v.get(name).cloned().flatten(),
881 _ => None,
882 };
883 match (payload, args.as_slice()) {
884 (None, []) => Ok(()),
885 (Some(payload), args) => {
886 let inst = subst_vars(&payload, &subst, &IndexMap::new());
887 if args.len() == 1 {
888 self.bind_pattern(&args[0], &inst, locals, node_id)?;
889 } else if let Ty::Tuple(items) = inst {
890 for (a, t) in args.iter().zip(items.iter()) {
891 self.bind_pattern(a, t, locals, node_id)?;
892 }
893 }
894 Ok(())
895 }
896 (None, _) => Err(TypeError::ArityMismatch {
897 at_node: node_id.into(), expected: 0, got: args.len(),
898 }),
899 }
900 }
901 a::Pattern::PRecord { fields } => {
902 let resolved = self.unfold_record_alias(self.u.resolve(ty));
907 let rec = match resolved {
908 Ty::Record(r) => r,
909 _ => return Err(TypeError::TypeMismatch {
910 at_node: node_id.into(),
911 expected: "record".into(),
912 got: ty.pretty(),
913 context: vec!["in record pattern".into()],
914 }),
915 };
916 for f in fields {
917 let ft = rec.get(&f.name).cloned()
918 .ok_or_else(|| TypeError::UnknownField {
919 at_node: node_id.into(),
920 record_type: Ty::Record(rec.clone()).pretty(),
921 field: f.name.clone(),
922 })?;
923 self.bind_pattern(&f.pattern, &ft, locals, node_id)?;
924 }
925 Ok(())
926 }
927 a::Pattern::PTuple { items } => {
928 let resolved = self.u.resolve(ty);
929 let tup = match resolved {
930 Ty::Tuple(t) => t,
931 Ty::Var(_) => {
932 let fresh: Vec<Ty> = items.iter().map(|_| self.u.fresh()).collect();
933 let tup_ty = Ty::Tuple(fresh.clone());
934 self.unify_with_record_coercion(&tup_ty, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in tuple pattern".into()]))?;
935 fresh
936 }
937 other => return Err(TypeError::TypeMismatch {
938 at_node: node_id.into(),
939 expected: "tuple".into(),
940 got: other.pretty(),
941 context: vec!["in tuple pattern".into()],
942 }),
943 };
944 if tup.len() != items.len() {
945 return Err(TypeError::ArityMismatch {
946 at_node: node_id.into(), expected: tup.len(), got: items.len(),
947 });
948 }
949 for (p, t) in items.iter().zip(tup.iter()) {
950 self.bind_pattern(p, t, locals, node_id)?;
951 }
952 Ok(())
953 }
954 }
955 }
956}
957
958fn lit_type(l: &a::CLit) -> Ty {
959 match l {
960 a::CLit::Int { .. } => Ty::int(),
961 a::CLit::Float { .. } => Ty::float(),
962 a::CLit::Str { .. } => Ty::str(),
963 a::CLit::Bytes { .. } => Ty::bytes(),
964 a::CLit::Bool { .. } => Ty::bool(),
965 a::CLit::Unit => Ty::Unit,
966 }
967}
968
969fn instantiate(s: &Scheme, u: &mut Unifier) -> Ty {
970 let mut ty_subst = IndexMap::new();
971 for v in &s.vars { ty_subst.insert(*v, u.fresh()); }
972 let mut eff_subst = IndexMap::new();
973 for v in &s.eff_vars { eff_subst.insert(*v, u.fresh_eff_id()); }
974 subst_vars(&s.ty, &ty_subst, &eff_subst)
975}
976
977fn subst_vars(
978 t: &Ty,
979 subst: &IndexMap<TyVarId, Ty>,
980 eff_subst: &IndexMap<u32, u32>,
981) -> Ty {
982 match t {
983 Ty::Var(v) => subst.get(v).cloned().unwrap_or_else(|| Ty::Var(*v)),
984 Ty::Prim(_) | Ty::Unit | Ty::Never => t.clone(),
985 Ty::List(inner) => Ty::List(Box::new(subst_vars(inner, subst, eff_subst))),
986 Ty::Tuple(items) => Ty::Tuple(items.iter().map(|t| subst_vars(t, subst, eff_subst)).collect()),
987 Ty::Record(fs) => {
988 let mut out = IndexMap::new();
989 for (k, v) in fs { out.insert(k.clone(), subst_vars(v, subst, eff_subst)); }
990 Ty::Record(out)
991 }
992 Ty::Con(n, args) => Ty::Con(n.clone(),
993 args.iter().map(|t| subst_vars(t, subst, eff_subst)).collect()),
994 Ty::Function { params, effects, ret } => {
995 let new_effects = EffectSet {
998 concrete: effects.concrete.clone(),
999 var: effects.var.and_then(|v| eff_subst.get(&v).copied()).or(effects.var),
1000 };
1001 Ty::Function {
1002 params: params.iter().map(|t| subst_vars(t, subst, eff_subst)).collect(),
1003 effects: new_effects,
1004 ret: Box::new(subst_vars(ret, subst, eff_subst)),
1005 }
1006 }
1007 }
1008}
1009
1010fn mismatch_err(node_id: &str, e: UnifyError, u: &Unifier, context: Vec<String>) -> TypeError {
1011 match e {
1012 UnifyError::Mismatch { a, b } => TypeError::TypeMismatch {
1013 at_node: node_id.into(),
1014 expected: u.resolve(&b).pretty(),
1015 got: u.resolve(&a).pretty(),
1016 context,
1017 },
1018 UnifyError::Infinite { .. } => TypeError::InfiniteType { at_node: node_id.into() },
1019 UnifyError::EffectMismatch { a, b } => {
1020 let render = |e: &EffectSet| -> String {
1024 let mut parts: Vec<String> = e.concrete.iter().cloned().collect();
1025 if let Some(v) = e.var { parts.push(format!("?e{}", v)); }
1026 if parts.is_empty() { "[]".into() } else { format!("[{}]", parts.join(", ")) }
1027 };
1028 TypeError::TypeMismatch {
1029 at_node: node_id.into(),
1030 expected: render(&b),
1031 got: render(&a),
1032 context,
1033 }
1034 }
1035 }
1036}