1use crate::builtins::{module_for_import, module_scope};
5use crate::env::{TypeDefKind, TypeEnv, ty_from_canon};
6use crate::error::TypeError;
7use crate::types::*;
8use crate::unifier::{UnifyError, Unifier};
9use indexmap::IndexMap;
10use lex_ast as a;
11use std::collections::HashMap;
12
/// Result of successfully type-checking a whole program with [`check_program`].
pub struct ProgramTypes {
    /// Declared type scheme of every function declaration, keyed by function
    /// name, in the order the declarations appear in the program.
    pub fn_signatures: IndexMap<String, Scheme>,
    /// The type environment after checking: the builtin types plus every
    /// user type declaration that was registered.
    pub type_env: TypeEnv,
    /// For each module `.parse(...)` call (`json`/`toml`/`yaml` import alias)
    /// whose inferred result is `Result[<record>, _]`, the record's field names.
    ///
    /// Keys are raw addresses of the `Call` nodes (`*const CExpr as usize`),
    /// so they only identify nodes of the exact `stages` slice that was
    /// checked, and only while that AST has not been moved or reallocated.
    pub parse_required_fields: HashMap<usize, Vec<String>>,
}
27
28pub fn check_program(stages: &[a::Stage]) -> Result<ProgramTypes, Vec<TypeError>> {
29 let mut tcx = Checker::new();
30 let mut errors = Vec::new();
31
32 for stage in stages {
34 if let a::Stage::Import(i) = stage {
35 if let Some(mod_name) = module_for_import(&i.reference) {
36 if let Some(ty) = module_scope(mod_name, &tcx.type_env) {
37 tcx.globals.insert(i.alias.clone(), Scheme {
38 vars: collect_vars(&ty),
42 eff_vars: collect_eff_vars(&ty),
43 ty,
44 });
45 tcx.module_aliases.insert(i.alias.clone(), mod_name.to_string());
46 }
47 }
48 }
49 }
50
51 for stage in stages {
53 if let a::Stage::TypeDecl(td) = stage {
54 if let Err(e) = tcx.type_env.add_user_type(&td.name, td.clone()) {
55 errors.push(TypeError::RecursiveTypeWithoutConstructor {
56 at_node: "n_0".into(),
57 name: e,
58 });
59 }
60 }
61 }
62
63 for stage in stages {
65 if let a::Stage::FnDecl(fd) = stage {
66 let scheme = function_scheme(fd);
67 tcx.globals.insert(fd.name.clone(), scheme);
68 tcx.fn_params.insert(fd.name.clone(), fd.params.clone());
72 }
73 }
74
75 let mut signatures = IndexMap::new();
77 for stage in stages {
78 if let a::Stage::FnDecl(fd) = stage {
79 match tcx.check_fn(fd) {
80 Ok(scheme) => { signatures.insert(fd.name.clone(), scheme); }
81 Err(es) => errors.extend(es),
82 }
83 }
84 }
85
86 if errors.is_empty() {
87 let mut parse_required_fields = HashMap::new();
93 for (call_ptr, ret_ty) in &tcx.pending_parse_calls {
94 if let Some(fields) = extract_record_fields_from_result(&tcx.u, &tcx.type_env, ret_ty) {
95 parse_required_fields.insert(*call_ptr, fields);
96 }
97 }
98 Ok(ProgramTypes {
99 fn_signatures: signatures,
100 type_env: tcx.type_env,
101 parse_required_fields,
102 })
103 } else {
104 Err(errors)
105 }
106}
107
108pub fn check_and_rewrite_program(
114 stages: &mut [a::Stage],
115) -> Result<ProgramTypes, Vec<TypeError>> {
116 let pt = check_program(&*stages)?;
121 if !pt.parse_required_fields.is_empty() {
122 rewrite_parse_calls(stages, &pt.parse_required_fields);
123 }
124 Ok(pt)
125}
126
127fn rewrite_parse_calls(stages: &mut [a::Stage], required: &HashMap<usize, Vec<String>>) {
141 for stage in stages.iter_mut() {
142 if let a::Stage::FnDecl(fd) = stage {
143 rewrite_in_expr(&mut fd.body, required);
144 }
145 }
146}
147
148fn rewrite_in_expr(expr: &mut a::CExpr, required: &HashMap<usize, Vec<String>>) {
149 let ptr = expr as *const a::CExpr as usize;
150 let do_rewrite = required.get(&ptr).cloned();
151 match expr {
157 a::CExpr::Call { callee, args } => {
158 rewrite_in_expr(callee, required);
159 for a in args.iter_mut() { rewrite_in_expr(a, required); }
160 }
161 a::CExpr::Let { value, body, .. } => {
162 rewrite_in_expr(value, required);
163 rewrite_in_expr(body, required);
164 }
165 a::CExpr::Match { scrutinee, arms } => {
166 rewrite_in_expr(scrutinee, required);
167 for arm in arms.iter_mut() { rewrite_in_expr(&mut arm.body, required); }
168 }
169 a::CExpr::Block { statements, result } => {
170 for s in statements.iter_mut() { rewrite_in_expr(s, required); }
171 rewrite_in_expr(result, required);
172 }
173 a::CExpr::Constructor { args, .. } => {
174 for a in args.iter_mut() { rewrite_in_expr(a, required); }
175 }
176 a::CExpr::RecordLit { fields } => {
177 for f in fields.iter_mut() { rewrite_in_expr(&mut f.value, required); }
178 }
179 a::CExpr::TupleLit { items } | a::CExpr::ListLit { items } => {
180 for it in items.iter_mut() { rewrite_in_expr(it, required); }
181 }
182 a::CExpr::FieldAccess { value, .. } => rewrite_in_expr(value, required),
183 a::CExpr::Lambda { body, .. } => rewrite_in_expr(body, required),
184 a::CExpr::BinOp { lhs, rhs, .. } => {
185 rewrite_in_expr(lhs, required);
186 rewrite_in_expr(rhs, required);
187 }
188 a::CExpr::UnaryOp { expr, .. } => rewrite_in_expr(expr, required),
189 a::CExpr::Return { value } => rewrite_in_expr(value, required),
190 a::CExpr::Literal { .. } | a::CExpr::Var { .. } => {}
191 }
192 if let Some(fields) = do_rewrite {
193 match expr {
194 a::CExpr::Call { callee, args } => {
195 if let a::CExpr::FieldAccess { field, .. } = callee.as_mut() {
196 debug_assert_eq!(field, "parse",
197 "rewrite_in_expr: only `.parse` calls should be in the table");
198 *field = "parse_strict".to_string();
199 }
200 args.push(a::CExpr::ListLit {
201 items: fields.into_iter()
202 .map(|f| a::CExpr::Literal {
203 value: a::CLit::Str { value: f },
204 })
205 .collect(),
206 });
207 }
208 _ => unreachable!("rewrite table key must point to a Call expression"),
209 }
210 }
211}
212
213fn extract_record_fields_from_result(
218 u: &Unifier,
219 env: &TypeEnv,
220 ty: &Ty,
221) -> Option<Vec<String>> {
222 let resolved = u.resolve(ty);
223 let Ty::Con(ref name, ref args) = resolved else { return None; };
224 if name != "Result" || args.len() != 2 { return None; }
225 let ok_ty = u.resolve(&args[0]);
226 let unfolded = unfold_record_alias_static(env, ok_ty);
227 if let Ty::Record(fields) = unfolded {
228 Some(fields.keys().cloned().collect())
229 } else {
230 None
231 }
232}
233
234fn unfold_record_alias_static(env: &TypeEnv, ty: Ty) -> Ty {
239 if let Ty::Con(ref n, _) = ty {
240 if let Some(td) = env.types.get(n) {
241 if let TypeDefKind::Alias(inner @ Ty::Record(_)) = &td.kind {
242 return inner.clone();
243 }
244 }
245 }
246 ty
247}
248
249fn collect_vars(t: &Ty) -> Vec<TyVarId> {
250 let mut out = Vec::new();
251 fn walk(t: &Ty, out: &mut Vec<TyVarId>) {
252 match t {
253 Ty::Var(v) => { if !out.contains(v) { out.push(*v); } }
254 Ty::Prim(_) | Ty::Unit | Ty::Never => {}
255 Ty::List(inner) => walk(inner, out),
256 Ty::Tuple(items) => for it in items { walk(it, out); },
257 Ty::Record(fs) => for v in fs.values() { walk(v, out); },
258 Ty::Con(_, args) => for a in args { walk(a, out); },
259 Ty::Function { params, ret, .. } => {
260 for p in params { walk(p, out); }
261 walk(ret, out);
262 }
263 }
264 }
265 walk(t, &mut out);
266 out
267}
268
269fn collect_eff_vars(t: &Ty) -> Vec<u32> {
273 let mut out = Vec::new();
274 fn walk(t: &Ty, out: &mut Vec<u32>) {
275 match t {
276 Ty::Var(_) | Ty::Prim(_) | Ty::Unit | Ty::Never => {}
277 Ty::List(inner) => walk(inner, out),
278 Ty::Tuple(items) => for it in items { walk(it, out); },
279 Ty::Record(fs) => for v in fs.values() { walk(v, out); },
280 Ty::Con(_, args) => for a in args { walk(a, out); },
281 Ty::Function { params, effects, ret } => {
282 if let Some(v) = effects.var {
283 if !out.contains(&v) { out.push(v); }
284 }
285 for p in params { walk(p, out); }
286 walk(ret, out);
287 }
288 }
289 }
290 walk(t, &mut out);
291 out
292}
293
294fn function_scheme(fd: &a::FnDecl) -> Scheme {
295 let params: Vec<Ty> = fd.params.iter().map(|p| ty_from_canon(&p.ty, &fd.type_params)).collect();
297 let ret = ty_from_canon(&fd.return_type, &fd.type_params);
298 let effects = EffectSet {
302 concrete: {
303 let mut s = std::collections::BTreeSet::new();
304 for e in &fd.effects {
305 let arg = e.arg.as_ref().map(|a| match a {
306 a::EffectArg::Str { value } => crate::types::EffectArg::Str(value.clone()),
307 a::EffectArg::Int { value } => crate::types::EffectArg::Int(*value),
308 a::EffectArg::Ident { value } => crate::types::EffectArg::Ident(value.clone()),
309 });
310 s.insert(crate::types::EffectKind { name: e.name.clone(), arg });
311 }
312 s
313 },
314 var: None,
315 };
316 let ty = Ty::Function { params, effects, ret: Box::new(ret) };
317 let vars: Vec<TyVarId> = (0..fd.type_params.len() as u32).collect();
318 Scheme { vars, eff_vars: Vec::new(), ty }
322}
323
/// Mutable state threaded through type checking of a single program.
struct Checker {
    /// Unification state; hands out fresh type variables and effect ids.
    u: Unifier,
    /// Builtin types plus user type declarations registered so far.
    type_env: TypeEnv,
    /// Schemes of import aliases and top-level functions, instantiated at each use.
    globals: IndexMap<String, Scheme>,
    /// Import alias -> module name, used to recognise `json`/`toml`/`yaml` `.parse` calls.
    module_aliases: IndexMap<String, String>,
    /// (address of the `Call` node, the call's return type) for every module
    /// `.parse` call seen while checking; resolved after all bodies are checked.
    pending_parse_calls: Vec<(usize, Ty)>,
    /// AST parameter lists of top-level functions, kept so that refinement
    /// predicates on parameters can be checked at direct call sites.
    fn_params: IndexMap<String, Vec<a::Param>>,
}
347
348impl Checker {
349 fn new() -> Self {
350 Self {
351 u: Unifier::new(),
352 type_env: TypeEnv::new_with_builtins(),
353 globals: IndexMap::new(),
354 module_aliases: IndexMap::new(),
355 pending_parse_calls: Vec::new(),
356 fn_params: IndexMap::new(),
357 }
358 }
359
360 fn unfold_record_alias(&self, ty: Ty) -> Ty {
364 if let Ty::Con(ref n, _) = ty {
365 if let Some(td) = self.type_env.types.get(n) {
366 if let TypeDefKind::Alias(inner @ Ty::Record(_)) = &td.kind {
367 return inner.clone();
368 }
369 }
370 }
371 ty
372 }
373
374 fn is_module_parse_call(&self, callee: &a::CExpr) -> bool {
379 if let a::CExpr::FieldAccess { value, field } = callee {
380 if field != "parse" { return false; }
381 if let a::CExpr::Var { name } = value.as_ref() {
382 if let Some(module) = self.module_aliases.get(name) {
383 return matches!(module.as_str(), "json" | "toml" | "yaml");
384 }
385 }
386 }
387 false
388 }
389
390 fn unify_with_record_coercion(&mut self, a: &Ty, b: &Ty) -> Result<(), UnifyError> {
402 let a = self.u.resolve(a);
403 let b = self.u.resolve(b);
404 self.unify_coerce_inner(a, b)
405 }
406
407 fn unify_coerce_inner(&mut self, a: Ty, b: Ty) -> Result<(), UnifyError> {
408 let (a, b) = match (&a, &b) {
410 (Ty::Record(_), Ty::Con(_, _)) => (a, self.unfold_record_alias(b.clone())),
411 (Ty::Con(_, _), Ty::Record(_)) => (self.unfold_record_alias(a.clone()), b),
412 _ => (a, b),
413 };
414
415 match (&a, &b) {
416 (Ty::Record(fa), Ty::Record(fb)) => {
417 if fa.len() != fb.len() {
418 return Err(UnifyError::Mismatch { a: a.clone(), b: b.clone() });
419 }
420 for (k, va) in fa.clone() {
421 match fb.get(&k) {
422 Some(vb) => self.unify_coerce_inner(va, vb.clone())?,
423 None => return Err(UnifyError::Mismatch { a: a.clone(), b: b.clone() }),
424 }
425 }
426 Ok(())
427 }
428 (Ty::List(ta), Ty::List(tb)) => {
429 self.unify_coerce_inner((**ta).clone(), (**tb).clone())
430 }
431 (Ty::Tuple(xs), Ty::Tuple(ys)) if xs.len() == ys.len() => {
432 for (x, y) in xs.clone().into_iter().zip(ys.clone()) {
433 self.unify_coerce_inner(x, y)?;
434 }
435 Ok(())
436 }
437 _ => self.u.unify(&a, &b),
438 }
439 }
440
441 fn check_fn(&mut self, fd: &a::FnDecl) -> Result<Scheme, Vec<TypeError>> {
442 let scheme = function_scheme(fd);
444 let (param_tys, declared_effects, ret_ty) = match instantiate(&scheme, &mut self.u) {
445 Ty::Function { params, effects, ret } => (params, effects, *ret),
446 _ => unreachable!(),
447 };
448
449 let mut locals: IndexMap<String, Ty> = IndexMap::new();
450 for (p, t) in fd.params.iter().zip(param_tys.iter()) {
451 locals.insert(p.name.clone(), t.clone());
452 }
453
454 let mut inferred_effects = EffectSet::empty();
455 let body_ty = self.check_expr(&fd.body, "n_0", &mut locals, &mut inferred_effects)
456 .map_err(|e| vec![e])?;
457
458 if let Err(e) = self.unify_with_record_coercion(&body_ty, &ret_ty) {
464 return Err(vec![mismatch_err("n_0", e, &self.u, vec![format!("in function `{}`", fd.name)])]);
465 }
466
467 if !inferred_effects.is_subset(&declared_effects) {
468 for e in inferred_effects.concrete.iter() {
470 if !declared_effects.concrete.iter().any(|d| d.subsumes(e)) {
471 return Err(vec![TypeError::EffectNotDeclared {
472 at_node: "n_0".into(),
473 effect: e.pretty(),
474 }]);
475 }
476 }
477 }
478
479 Ok(scheme)
480 }
481
482 fn check_expr(
483 &mut self,
484 e: &a::CExpr,
485 node_id: &str,
486 locals: &mut IndexMap<String, Ty>,
487 effs: &mut EffectSet,
488 ) -> Result<Ty, TypeError> {
489 match e {
490 a::CExpr::Literal { value } => Ok(lit_type(value)),
491 a::CExpr::Var { name } => {
492 if let Some(t) = locals.get(name) {
493 return Ok(t.clone());
494 }
495 if let Some(scheme) = self.globals.get(name).cloned() {
496 return Ok(instantiate(&scheme, &mut self.u));
497 }
498 Err(TypeError::UnknownIdentifier { at_node: node_id.into(), name: name.clone() })
499 }
500 a::CExpr::Constructor { name, args } => self.check_constructor(name, args, node_id, locals, effs),
501 a::CExpr::Call { callee, args } => self.check_call(e, callee, args, node_id, locals, effs),
502 a::CExpr::Let { name, ty, value, body } => {
503 let v_ty = self.check_expr(value, node_id, locals, effs)?;
504 if let Some(declared) = ty {
505 let d = ty_from_canon(declared, &[]);
506 if let Err(err) = self.unify_with_record_coercion(&v_ty, &d) {
507 return Err(mismatch_err(node_id, err, &self.u, vec![format!("in let `{}`", name)]));
508 }
509 }
510 let prev = locals.insert(name.clone(), v_ty);
511 let body_ty = self.check_expr(body, node_id, locals, effs)?;
512 match prev {
513 Some(p) => { locals.insert(name.clone(), p); }
514 None => { locals.shift_remove(name); }
515 }
516 Ok(body_ty)
517 }
518 a::CExpr::Match { scrutinee, arms } => {
519 let scrut_ty = self.check_expr(scrutinee, node_id, locals, effs)?;
520 if arms.is_empty() {
521 return Err(TypeError::NonExhaustiveMatch {
522 at_node: node_id.into(), missing: vec!["_".into()]
523 });
524 }
525 let result_ty = self.u.fresh();
526 for arm in arms {
527 let mut arm_locals = locals.clone();
528 self.bind_pattern(&arm.pattern, &scrut_ty, &mut arm_locals, node_id)?;
529 let arm_ty = self.check_expr(&arm.body, node_id, &mut arm_locals, effs)?;
530 if let Err(err) = self.unify_with_record_coercion(&arm_ty, &result_ty) {
531 return Err(mismatch_err(node_id, err, &self.u, vec!["in match arm".into()]));
532 }
533 }
534 Ok(result_ty)
535 }
536 a::CExpr::Block { statements, result } => {
537 for s in statements {
538 self.check_expr(s, node_id, locals, effs)?;
539 }
540 self.check_expr(result, node_id, locals, effs)
541 }
542 a::CExpr::RecordLit { fields } => {
543 let mut tys = IndexMap::new();
544 for f in fields {
545 if tys.contains_key(&f.name) {
546 return Err(TypeError::DuplicateField {
547 at_node: node_id.into(), field: f.name.clone()
548 });
549 }
550 let ft = self.check_expr(&f.value, node_id, locals, effs)?;
551 tys.insert(f.name.clone(), ft);
552 }
553 Ok(Ty::Record(tys))
554 }
555 a::CExpr::TupleLit { items } => {
556 let mut ts = Vec::new();
557 for it in items { ts.push(self.check_expr(it, node_id, locals, effs)?); }
558 Ok(Ty::Tuple(ts))
559 }
560 a::CExpr::ListLit { items } => {
561 let elem = self.u.fresh();
562 for it in items {
563 let t = self.check_expr(it, node_id, locals, effs)?;
564 if let Err(err) = self.unify_with_record_coercion(&t, &elem) {
565 return Err(mismatch_err(node_id, err, &self.u, vec!["in list literal".into()]));
566 }
567 }
568 Ok(Ty::List(Box::new(elem)))
569 }
570 a::CExpr::FieldAccess { value, field } => {
571 let vt = self.check_expr(value, node_id, locals, effs)?;
572 let resolved = self.u.resolve(&vt);
573 let resolved = match resolved {
575 Ty::Con(ref n, _) => match self.type_env.types.get(n) {
576 Some(td) => match &td.kind {
577 TypeDefKind::Alias(inner @ Ty::Record(_)) => inner.clone(),
578 _ => resolved,
579 },
580 None => resolved,
581 },
582 other => other,
583 };
584 match resolved {
585 Ty::Record(fields) => fields.get(field).cloned()
586 .ok_or_else(|| TypeError::UnknownField {
587 at_node: node_id.into(),
588 record_type: Ty::Record(fields.clone()).pretty(),
589 field: field.clone(),
590 }),
591 other => Err(TypeError::TypeMismatch {
592 at_node: node_id.into(),
593 expected: "record".into(),
594 got: other.pretty(),
595 context: vec![format!("field access `.{}`", field)],
596 }),
597 }
598 }
599 a::CExpr::Lambda { params, return_type, effects: l_effects, body } => {
600 let param_tys: Vec<Ty> = params.iter().map(|p| ty_from_canon(&p.ty, &[])).collect();
601 let ret_ty = ty_from_canon(return_type, &[]);
602 let declared = EffectSet {
603 concrete: {
604 let mut s = std::collections::BTreeSet::new();
605 for e in l_effects {
606 let arg = e.arg.as_ref().map(|a| match a {
607 a::EffectArg::Str { value } => crate::types::EffectArg::Str(value.clone()),
608 a::EffectArg::Int { value } => crate::types::EffectArg::Int(*value),
609 a::EffectArg::Ident { value } => crate::types::EffectArg::Ident(value.clone()),
610 });
611 s.insert(crate::types::EffectKind { name: e.name.clone(), arg });
612 }
613 s
614 },
615 var: None,
616 };
617 let mut inner_locals = locals.clone();
618 for (p, t) in params.iter().zip(param_tys.iter()) {
619 inner_locals.insert(p.name.clone(), t.clone());
620 }
621 let mut inner_effs = EffectSet::empty();
622 let body_ty = self.check_expr(body, node_id, &mut inner_locals, &mut inner_effs)?;
623 if let Err(err) = self.unify_with_record_coercion(&body_ty, &ret_ty) {
624 return Err(mismatch_err(node_id, err, &self.u, vec!["in lambda body".into()]));
625 }
626 if !inner_effs.is_subset(&declared) {
627 for e in inner_effs.concrete.iter() {
628 if !declared.concrete.iter().any(|d| d.subsumes(e)) {
629 return Err(TypeError::EffectNotDeclared {
630 at_node: node_id.into(),
631 effect: e.pretty(),
632 });
633 }
634 }
635 }
636 Ok(Ty::function(param_tys, declared, ret_ty))
637 }
638 a::CExpr::BinOp { op, lhs, rhs } => self.check_binop(op, lhs, rhs, node_id, locals, effs),
639 a::CExpr::UnaryOp { op, expr } => {
640 let t = self.check_expr(expr, node_id, locals, effs)?;
641 match op.as_str() {
642 "-" => {
643 let r = self.u.resolve(&t);
645 match r {
646 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) => Ok(t),
647 Ty::Var(_) => {
648 self.u.unify(&t, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![]))?;
650 Ok(Ty::int())
651 }
652 other => Err(TypeError::TypeMismatch {
653 at_node: node_id.into(),
654 expected: "Int or Float".into(),
655 got: other.pretty(),
656 context: vec!["unary `-`".into()],
657 }),
658 }
659 }
660 "not" => {
661 self.u.unify(&t, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["unary `not`".into()]))?;
662 Ok(Ty::bool())
663 }
664 other => panic!("unknown unary op: {other}"),
665 }
666 }
667 a::CExpr::Return { value } => {
668 self.check_expr(value, node_id, locals, effs)?;
671 Ok(Ty::Never)
672 }
673 }
674 }
675
676 fn check_binop(
677 &mut self,
678 op: &str,
679 lhs: &a::CExpr,
680 rhs: &a::CExpr,
681 node_id: &str,
682 locals: &mut IndexMap<String, Ty>,
683 effs: &mut EffectSet,
684 ) -> Result<Ty, TypeError> {
685 let lt = self.check_expr(lhs, node_id, locals, effs)?;
686 let rt = self.check_expr(rhs, node_id, locals, effs)?;
687 match op {
688 "+" | "-" | "*" | "/" | "%" => {
689 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
690 let r = self.u.resolve(<);
691 match r {
692 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) => Ok(lt),
693 Ty::Var(_) => {
694 self.u.unify(<, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
695 Ok(Ty::int())
696 }
697 other => Err(TypeError::TypeMismatch {
698 at_node: node_id.into(),
699 expected: "Int or Float".into(),
700 got: other.pretty(),
701 context: vec![format!("operator `{op}`")],
702 }),
703 }
704 }
705 "==" | "!=" => {
706 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
707 Ok(Ty::bool())
708 }
709 "<" | "<=" | ">" | ">=" => {
710 self.u.unify(<, &rt).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
711 let r = self.u.resolve(<);
712 match r {
713 Ty::Prim(Prim::Int) | Ty::Prim(Prim::Float) | Ty::Prim(Prim::Str) => Ok(Ty::bool()),
714 Ty::Var(_) => {
715 self.u.unify(<, &Ty::int()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
716 Ok(Ty::bool())
717 }
718 other => Err(TypeError::TypeMismatch {
719 at_node: node_id.into(),
720 expected: "Int, Float, or Str".into(),
721 got: other.pretty(),
722 context: vec![format!("operator `{op}`")],
723 }),
724 }
725 }
726 "and" | "or" => {
727 self.u.unify(<, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
728 self.u.unify(&rt, &Ty::bool()).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("operator `{op}`")]))?;
729 Ok(Ty::bool())
730 }
731 other => panic!("unknown binop: {other}"),
732 }
733 }
734
735 fn check_call(
736 &mut self,
737 call_expr: &a::CExpr,
738 callee: &a::CExpr,
739 args: &[a::CExpr],
740 node_id: &str,
741 locals: &mut IndexMap<String, Ty>,
742 effs: &mut EffectSet,
743 ) -> Result<Ty, TypeError> {
744 let parse_call_ptr = if self.is_module_parse_call(callee) {
752 Some(call_expr as *const a::CExpr as usize)
753 } else {
754 None
755 };
756 let callee_ty = self.check_expr(callee, node_id, locals, effs)?;
757 let resolved = self.u.resolve(&callee_ty);
758 match resolved {
759 Ty::Function { params, effects, ret } => {
760 if params.len() != args.len() {
761 return Err(TypeError::ArityMismatch {
762 at_node: node_id.into(),
763 expected: params.len(),
764 got: args.len(),
765 });
766 }
767 for (i, (a, p)) in args.iter().zip(params.iter()).enumerate() {
768 let at = self.check_expr(a, node_id, locals, effs)?;
769 if let Err(err) = self.unify_with_record_coercion(&at, p) {
770 return Err(mismatch_err(node_id, err, &self.u, vec![format!("argument {} of call", i + 1)]));
771 }
772 }
773 if let a::CExpr::Var { name: callee_name } = callee {
780 if let Some(callee_params) = self.fn_params.get(callee_name).cloned() {
781 for (i, (param, arg)) in callee_params.iter().zip(args.iter()).enumerate() {
782 if let a::TypeExpr::Refined { binding, predicate, .. } = ¶m.ty {
783 let outcome = crate::discharge::try_discharge(
784 predicate, binding, arg);
785 if let crate::discharge::DischargeOutcome::Refuted { reason } = outcome {
786 return Err(TypeError::RefinementViolation {
787 at_node: node_id.into(),
788 fn_name: callee_name.clone(),
789 param_index: i,
790 binding: binding.clone(),
791 reason,
792 });
793 }
794 }
795 }
796 }
797 }
798 let resolved_effects = self.u.resolve_effects(&effects);
803 effs.extend(&resolved_effects);
804 if let Some(ptr) = parse_call_ptr {
812 self.pending_parse_calls.push((ptr, (*ret).clone()));
813 }
814 Ok(*ret)
815 }
816 Ty::Var(_) => {
817 let mut p_tys = Vec::new();
819 for a in args { p_tys.push(self.check_expr(a, node_id, locals, effs)?); }
820 let r = self.u.fresh();
821 let f = Ty::function(p_tys, EffectSet::empty(), r.clone());
822 self.u.unify(&callee_ty, &f).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in call".into()]))?;
823 Ok(r)
824 }
825 other => Err(TypeError::TypeMismatch {
826 at_node: node_id.into(),
827 expected: "function".into(),
828 got: other.pretty(),
829 context: vec!["in call".into()],
830 }),
831 }
832 }
833
834 fn check_constructor(
835 &mut self,
836 name: &str,
837 args: &[a::CExpr],
838 node_id: &str,
839 locals: &mut IndexMap<String, Ty>,
840 effs: &mut EffectSet,
841 ) -> Result<Ty, TypeError> {
842 let owning = self.type_env.ctor_to_type.get(name).cloned()
843 .ok_or_else(|| TypeError::UnknownVariant {
844 at_node: node_id.into(),
845 constructor: name.to_string(),
846 })?;
847 let def = self.type_env.types.get(&owning).cloned()
848 .expect("ctor_to_type points to a real type");
849 let variants = match &def.kind {
850 TypeDefKind::Union(v) => v.clone(),
851 _ => return Err(TypeError::UnknownVariant {
852 at_node: node_id.into(),
853 constructor: name.to_string(),
854 }),
855 };
856 let mut subst = IndexMap::new();
859 let mut con_args = Vec::with_capacity(def.params.len());
860 for (i, _p) in def.params.iter().enumerate() {
861 let fresh = self.u.fresh();
862 subst.insert(i as u32, fresh.clone());
863 con_args.push(fresh);
864 }
865 let payload = variants.get(name).cloned().flatten();
866 match (payload, args) {
867 (None, []) => Ok(Ty::Con(owning, con_args)),
868 (Some(payload), args) => {
869 let inst_payload = subst_vars(&payload, &subst, &IndexMap::new());
870 let arg_count = match &inst_payload {
871 Ty::Tuple(items) => items.len(),
872 _ => 1,
873 };
874 if arg_count != args.len() {
875 return Err(TypeError::ArityMismatch {
876 at_node: node_id.into(),
877 expected: arg_count,
878 got: args.len(),
879 });
880 }
881 if args.len() == 1 {
882 let at = self.check_expr(&args[0], node_id, locals, effs)?;
883 self.unify_with_record_coercion(&at, &inst_payload).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor `{}`", name)]))?;
884 } else if let Ty::Tuple(items) = inst_payload {
885 for (i, (a, t)) in args.iter().zip(items.iter()).enumerate() {
886 let at = self.check_expr(a, node_id, locals, effs)?;
887 self.unify_with_record_coercion(&at, t).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor `{}` arg {}", name, i + 1)]))?;
888 }
889 }
890 Ok(Ty::Con(owning, con_args))
891 }
892 (None, _) => Err(TypeError::ArityMismatch {
893 at_node: node_id.into(), expected: 0, got: args.len(),
894 }),
895 }
896 }
897
898 fn bind_pattern(
899 &mut self,
900 pat: &a::Pattern,
901 ty: &Ty,
902 locals: &mut IndexMap<String, Ty>,
903 node_id: &str,
904 ) -> Result<(), TypeError> {
905 match pat {
906 a::Pattern::PWild => Ok(()),
907 a::Pattern::PVar { name } => {
908 locals.insert(name.clone(), ty.clone());
909 Ok(())
910 }
911 a::Pattern::PLiteral { value } => {
912 let lt = lit_type(value);
913 self.unify_with_record_coercion(<, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in pattern".into()]))?;
914 Ok(())
915 }
916 a::Pattern::PConstructor { name, args } => {
917 let owning = self.type_env.ctor_to_type.get(name).cloned()
919 .ok_or_else(|| TypeError::UnknownVariant {
920 at_node: node_id.into(), constructor: name.clone(),
921 })?;
922 let def = self.type_env.types.get(&owning).cloned().unwrap();
923 let mut subst = IndexMap::new();
924 let mut con_args = Vec::new();
925 for (i, _) in def.params.iter().enumerate() {
926 let fresh = self.u.fresh();
927 subst.insert(i as u32, fresh.clone());
928 con_args.push(fresh);
929 }
930 let con_ty = Ty::Con(owning.clone(), con_args);
931 self.unify_with_record_coercion(&con_ty, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec![format!("constructor pattern `{}`", name)]))?;
932 let payload = match &def.kind {
933 TypeDefKind::Union(v) => v.get(name).cloned().flatten(),
934 _ => None,
935 };
936 match (payload, args.as_slice()) {
937 (None, []) => Ok(()),
938 (Some(payload), args) => {
939 let inst = subst_vars(&payload, &subst, &IndexMap::new());
940 if args.len() == 1 {
941 self.bind_pattern(&args[0], &inst, locals, node_id)?;
942 } else if let Ty::Tuple(items) = inst {
943 for (a, t) in args.iter().zip(items.iter()) {
944 self.bind_pattern(a, t, locals, node_id)?;
945 }
946 }
947 Ok(())
948 }
949 (None, _) => Err(TypeError::ArityMismatch {
950 at_node: node_id.into(), expected: 0, got: args.len(),
951 }),
952 }
953 }
954 a::Pattern::PRecord { fields } => {
955 let resolved = self.unfold_record_alias(self.u.resolve(ty));
960 let rec = match resolved {
961 Ty::Record(r) => r,
962 _ => return Err(TypeError::TypeMismatch {
963 at_node: node_id.into(),
964 expected: "record".into(),
965 got: ty.pretty(),
966 context: vec!["in record pattern".into()],
967 }),
968 };
969 for f in fields {
970 let ft = rec.get(&f.name).cloned()
971 .ok_or_else(|| TypeError::UnknownField {
972 at_node: node_id.into(),
973 record_type: Ty::Record(rec.clone()).pretty(),
974 field: f.name.clone(),
975 })?;
976 self.bind_pattern(&f.pattern, &ft, locals, node_id)?;
977 }
978 Ok(())
979 }
980 a::Pattern::PTuple { items } => {
981 let resolved = self.u.resolve(ty);
982 let tup = match resolved {
983 Ty::Tuple(t) => t,
984 Ty::Var(_) => {
985 let fresh: Vec<Ty> = items.iter().map(|_| self.u.fresh()).collect();
986 let tup_ty = Ty::Tuple(fresh.clone());
987 self.unify_with_record_coercion(&tup_ty, ty).map_err(|e| mismatch_err(node_id, e, &self.u, vec!["in tuple pattern".into()]))?;
988 fresh
989 }
990 other => return Err(TypeError::TypeMismatch {
991 at_node: node_id.into(),
992 expected: "tuple".into(),
993 got: other.pretty(),
994 context: vec!["in tuple pattern".into()],
995 }),
996 };
997 if tup.len() != items.len() {
998 return Err(TypeError::ArityMismatch {
999 at_node: node_id.into(), expected: tup.len(), got: items.len(),
1000 });
1001 }
1002 for (p, t) in items.iter().zip(tup.iter()) {
1003 self.bind_pattern(p, t, locals, node_id)?;
1004 }
1005 Ok(())
1006 }
1007 }
1008 }
1009}
1010
1011fn lit_type(l: &a::CLit) -> Ty {
1012 match l {
1013 a::CLit::Int { .. } => Ty::int(),
1014 a::CLit::Float { .. } => Ty::float(),
1015 a::CLit::Str { .. } => Ty::str(),
1016 a::CLit::Bytes { .. } => Ty::bytes(),
1017 a::CLit::Bool { .. } => Ty::bool(),
1018 a::CLit::Unit => Ty::Unit,
1019 }
1020}
1021
1022fn instantiate(s: &Scheme, u: &mut Unifier) -> Ty {
1023 let mut ty_subst = IndexMap::new();
1024 for v in &s.vars { ty_subst.insert(*v, u.fresh()); }
1025 let mut eff_subst = IndexMap::new();
1026 for v in &s.eff_vars { eff_subst.insert(*v, u.fresh_eff_id()); }
1027 subst_vars(&s.ty, &ty_subst, &eff_subst)
1028}
1029
1030fn subst_vars(
1031 t: &Ty,
1032 subst: &IndexMap<TyVarId, Ty>,
1033 eff_subst: &IndexMap<u32, u32>,
1034) -> Ty {
1035 match t {
1036 Ty::Var(v) => subst.get(v).cloned().unwrap_or_else(|| Ty::Var(*v)),
1037 Ty::Prim(_) | Ty::Unit | Ty::Never => t.clone(),
1038 Ty::List(inner) => Ty::List(Box::new(subst_vars(inner, subst, eff_subst))),
1039 Ty::Tuple(items) => Ty::Tuple(items.iter().map(|t| subst_vars(t, subst, eff_subst)).collect()),
1040 Ty::Record(fs) => {
1041 let mut out = IndexMap::new();
1042 for (k, v) in fs { out.insert(k.clone(), subst_vars(v, subst, eff_subst)); }
1043 Ty::Record(out)
1044 }
1045 Ty::Con(n, args) => Ty::Con(n.clone(),
1046 args.iter().map(|t| subst_vars(t, subst, eff_subst)).collect()),
1047 Ty::Function { params, effects, ret } => {
1048 let new_effects = EffectSet {
1051 concrete: effects.concrete.clone(),
1052 var: effects.var.and_then(|v| eff_subst.get(&v).copied()).or(effects.var),
1053 };
1054 Ty::Function {
1055 params: params.iter().map(|t| subst_vars(t, subst, eff_subst)).collect(),
1056 effects: new_effects,
1057 ret: Box::new(subst_vars(ret, subst, eff_subst)),
1058 }
1059 }
1060 }
1061}
1062
1063fn mismatch_err(node_id: &str, e: UnifyError, u: &Unifier, context: Vec<String>) -> TypeError {
1064 match e {
1065 UnifyError::Mismatch { a, b } => TypeError::TypeMismatch {
1066 at_node: node_id.into(),
1067 expected: u.resolve(&b).pretty(),
1068 got: u.resolve(&a).pretty(),
1069 context,
1070 },
1071 UnifyError::Infinite { .. } => TypeError::InfiniteType { at_node: node_id.into() },
1072 UnifyError::EffectMismatch { a, b } => {
1073 let render = |e: &EffectSet| -> String {
1077 let mut parts: Vec<String> = e.concrete.iter()
1078 .map(crate::types::EffectKind::pretty).collect();
1079 if let Some(v) = e.var { parts.push(format!("?e{}", v)); }
1080 if parts.is_empty() { "[]".into() } else { format!("[{}]", parts.join(", ")) }
1081 };
1082 TypeError::TypeMismatch {
1083 at_node: node_id.into(),
1084 expected: render(&b),
1085 got: render(&a),
1086 context,
1087 }
1088 }
1089 }
1090}