1use super::*;
2
3use crate::utils::prim::Prim;
4use crate::utils::unify::*;
5
6#[derive(Clone, Debug)]
7struct FuncTyScm {
8 polys: Vec<Ident>,
9 pars: Vec<TermType>,
10 res: TermType,
11}
12
13#[derive(Clone, Debug)]
14struct ConsTyScm {
15 polys: Vec<Ident>,
16 flds: Vec<TermType>,
17 res: TermType,
18}
19
20#[allow(unused)]
21#[derive(Clone, Debug)]
22struct DataTyScm {
23 polys: Vec<Ident>,
24}
25
26#[derive(Clone, Debug, PartialEq)]
27pub enum CheckError {
28 UnifyFailed {
29 typ1: TermType,
30 typ2: TermType,
31 span: Span,
32 },
33 OccurCheckFailed {
34 var: Ident,
35 typ: TermType,
36 span: Span,
37 },
38 UnifyVecDiffLen {
39 vec1: Vec<TermType>,
40 vec2: Vec<TermType>,
41 span: Span,
42 },
43 TypeArityMismatch {
44 actual: usize,
45 expected: usize,
46 span: Span,
47 },
48}
49
50use crate::cli::diagnostic::Diagnostic;
51impl From<CheckError> for Diagnostic {
52 fn from(val: CheckError) -> Self {
53 match val {
54 CheckError::UnifyFailed {typ1, typ2, span } => {
55 Diagnostic::error("cannot match type!".to_string()).line_span(
56 span.clone(),
57 format!("the expression here has type {typ1}, but expected {typ2}."),
58 )
59 }
60 CheckError::OccurCheckFailed { var, typ, span } => {
61 Diagnostic::error("occurrence check failed!".to_string()).line_span(
62 span.clone(),
63 format!("failed to unify the variable {var} with type {typ}, since it occurs in its own type."),
64 )
65 }
66 CheckError::UnifyVecDiffLen { vec1, vec2, span } => {
67 Diagnostic::error("type vectors have different length!".to_string()).line_span(
68 span.clone(),
69 format!("failed to unify two vectors with lengths: {vec1:?} and {vec2:?}"),
70 )
71 }
72 CheckError::TypeArityMismatch { actual, expected, span } => {
73 Diagnostic::error("type arity mismatch!".to_string()).line_span(
74 span.clone(),
75 format!("the type constructor has arity {actual}, but expected arity {expected}."),
76 )
77 }
78 }
79 }
80}
81
82struct Checker {
83 val_ctx: HashMap<Ident, TermType>,
84 func_ctx: HashMap<Ident, FuncTyScm>,
85 cons_ctx: HashMap<Ident, ConsTyScm>,
86 data_ctx: HashMap<Ident, DataTyScm>,
87 unifier: Unifier<Ident, LitType, OptCons<Ident>>,
88 errors: Vec<CheckError>,
89}
90
91impl Checker {
92 pub fn new() -> Checker {
93 Checker {
94 val_ctx: HashMap::new(),
95 func_ctx: HashMap::new(),
96 cons_ctx: HashMap::new(),
97 data_ctx: HashMap::new(),
98 unifier: Unifier::new(),
99 errors: Vec::new(),
100 }
101 }
102
103 fn fresh(&mut self) -> TermType {
104 TermType::Var(Ident::fresh(&"a"))
105 }
106
107 fn unify(&mut self, typ1: &TermType, typ2: &TermType, span: &Span) {
108 match self.unifier.unify(typ1, typ2) {
109 Ok(()) => {}
110 Err(UnifyError::UnifyFailed(typ1, typ2)) => {
111 self.errors.push(CheckError::UnifyFailed {
112 typ1,
113 typ2,
114 span: span.clone(),
115 });
116 }
117 Err(UnifyError::OccurCheckFailed(var, typ)) => {
118 self.errors.push(CheckError::OccurCheckFailed {
119 var,
120 typ,
121 span: span.clone(),
122 });
123 }
124 Err(UnifyError::UnifyVecDiffLen(vec1, vec2)) => {
125 self.errors.push(CheckError::UnifyVecDiffLen {
126 vec1,
127 vec2,
128 span: span.clone(),
129 });
130 }
131 }
132 }
133
134 fn unify_many(&mut self, vec1: &[TermType], vec2: &[(TermType, Span)], span: &Span) {
135 if vec1.len() == vec2.len() {
136 for (lhs, (rhs, span)) in vec1.iter().zip(vec2.iter()) {
137 self.unify(lhs, rhs, span);
138 }
139 } else {
140 self.errors.push(CheckError::UnifyVecDiffLen {
141 vec1: vec1.to_vec(),
142 vec2: vec2.iter().map(|x| x.0.clone()).collect(),
143 span: span.clone(),
144 });
145 }
146 }
147
148 fn check_prim(&mut self, prim: &Prim, args: &[Expr], span: &Span) -> TermType {
149 let args: Vec<_> = args
150 .iter()
151 .map(|arg| (self.infer_expr(arg), arg.get_span()))
152 .collect();
153
154 match prim {
155 Prim::IAdd | Prim::ISub | Prim::IMul | Prim::IDiv | Prim::IRem => {
156 self.unify_many(
157 &[TermType::Lit(LitType::TyInt), TermType::Lit(LitType::TyInt)],
158 &args,
159 span,
160 );
161 TermType::Lit(LitType::TyInt)
162 }
163 Prim::INeg => {
164 self.unify_many(&[TermType::Lit(LitType::TyInt)], &args, span);
165 TermType::Lit(LitType::TyInt)
166 }
167 Prim::ICmp(_) => {
168 self.unify_many(
169 &[TermType::Lit(LitType::TyInt), TermType::Lit(LitType::TyInt)],
170 &args,
171 span,
172 );
173 TermType::Lit(LitType::TyBool)
174 }
175 Prim::BAnd | Prim::BOr => {
176 self.unify_many(
177 &[
178 TermType::Lit(LitType::TyBool),
179 TermType::Lit(LitType::TyBool),
180 ],
181 &args,
182 span,
183 );
184 TermType::Lit(LitType::TyBool)
185 }
186 Prim::BNot => {
187 self.unify_many(&[TermType::Lit(LitType::TyBool)], &args, span);
188 TermType::Lit(LitType::TyBool)
189 }
190 }
191 }
192
193 fn infer_expr(&mut self, expr: &Expr) -> TermType {
194 match expr {
195 Expr::Lit { lit, span: _ } => TermType::Lit(lit.get_typ()),
196 Expr::Var { var, span: _ } => self.val_ctx[&var.ident].clone(),
197 Expr::Prim { prim, args, span } => self.check_prim(prim, args, span),
198 Expr::Cons { cons, flds, span } => {
199 let cons_scm = &self.cons_ctx[&cons.ident];
201
202 let inst_map: HashMap<Ident, TermType> = cons_scm
203 .polys
204 .iter()
205 .map(|poly| (*poly, Term::Var(poly.uniquify())))
206 .collect();
207
208 let inst_flds: Vec<_> = cons_scm
209 .flds
210 .iter()
211 .map(|fld| fld.substitute(&inst_map))
212 .collect();
213
214 let inst_res = cons_scm.res.substitute(&inst_map);
215
216 let flds: Vec<_> = flds
217 .iter()
218 .map(|fld| (self.infer_expr(fld), fld.get_span()))
219 .collect();
220
221 self.unify_many(&inst_flds, &flds, span);
222 inst_res
223 }
224 Expr::Tuple { flds, span: _ } => {
225 let flds: Vec<TermType> = flds.iter().map(|fld| self.infer_expr(fld)).collect();
226 TermType::Cons(OptCons::None, flds)
227 }
228 Expr::Match {
229 expr,
230 brchs,
231 span: _,
232 } => {
233 let expr_ty = self.infer_expr(expr);
234 let res = self.fresh();
235 for (patn, cont) in brchs {
236 let patn_ty = self.check_patn(patn);
237 let patn_span = patn.get_span();
238 self.unify(&patn_ty, &expr_ty, &patn_span);
239 let cont_ty = self.infer_expr(cont);
240 let cont_span = cont.get_span();
241 self.unify(&res, &cont_ty, &cont_span);
242 }
243 res
244 }
245 Expr::Let {
246 patn,
247 expr,
248 cont,
249 span: _,
250 } => {
251 let expr_ty = self.infer_expr(expr);
252 let expr_span = expr.get_span();
253 let patn_ty = self.check_patn(patn);
254 self.unify(&patn_ty, &expr_ty, &expr_span);
255 self.infer_expr(cont)
256 }
257 Expr::App { func, args, span } => {
258 let func_scm = &self.func_ctx[&func.ident];
260
261 let inst_map: HashMap<Ident, TermType> = func_scm
262 .polys
263 .iter()
264 .map(|poly| (*poly, Term::Var(poly.uniquify())))
265 .collect();
266
267 let inst_pars: Vec<_> = func_scm
268 .pars
269 .iter()
270 .map(|par| par.substitute(&inst_map))
271 .collect();
272
273 let inst_res = func_scm.res.substitute(&inst_map);
274
275 let args: Vec<_> = args
276 .iter()
277 .map(|arg| (self.infer_expr(arg), arg.get_span()))
278 .collect();
279
280 self.unify_many(&inst_pars, &args, span);
281 inst_res
282 }
283 Expr::Ifte {
284 cond,
285 then,
286 els,
287 span: _,
288 } => {
289 let cond_ty = self.infer_expr(cond);
290 let cond_span = cond.get_span();
291 self.unify(&cond_ty, &TermType::Lit(LitType::TyBool), &cond_span);
292 let then_ty = self.infer_expr(then);
293 let els_ty = self.infer_expr(els);
294 let els_span = els.get_span();
295 self.unify(&then_ty, &els_ty, &els_span);
296 then_ty
297 }
298 Expr::Cond { brchs, span: _ } => {
299 let res = self.fresh();
300 for (cond, body) in brchs {
301 let cond_ty = self.infer_expr(cond);
302 let cond_span = cond.get_span();
303 let body_ty = self.infer_expr(body);
304 let body_span = body.get_span();
305 self.unify(&cond_ty, &TermType::Lit(LitType::TyBool), &cond_span);
306 self.unify(&body_ty, &res, &body_span);
307 }
308 res
309 }
310 Expr::Alter { brchs, span: _ } => {
311 let res = self.fresh();
312 for body in brchs {
313 let body_ty = self.infer_expr(body);
314 let body_span = body.get_span();
315 self.unify(&body_ty, &res, &body_span);
316 }
317 res
318 }
319 Expr::Fresh {
320 vars,
321 cont,
322 span: _,
323 } => {
324 for var in vars {
325 let cell = self.fresh();
326 self.val_ctx.insert(var.ident, cell);
327 }
328 self.infer_expr(cont)
329 }
330 Expr::Guard {
331 lhs,
332 rhs,
333 cont,
334 span: _,
335 } => {
336 let lhs_ty = self.infer_expr(lhs);
337 if let Some(rhs) = rhs {
338 let rhs_ty = self.infer_expr(rhs);
339 let rhs_span = rhs.get_span();
340 self.unify(&lhs_ty, &rhs_ty, &rhs_span);
341 } else {
342 let lhs_span = lhs.get_span();
343 self.unify(
344 &lhs_ty,
345 &TermType::Cons(OptCons::None, Vec::new()),
346 &lhs_span,
347 );
348 }
349 self.infer_expr(cont)
350 }
351 Expr::Undefined { span: _ } => self.fresh(),
352 }
353 }
354
355 fn check_patn(&mut self, patn: &Pattern) -> TermType {
356 match patn {
357 Pattern::Lit { lit, span: _ } => TermType::Lit(lit.get_typ()),
358 Pattern::Var { var, span: _ } => {
359 let ty = self.fresh();
360 self.val_ctx.insert(var.ident, ty.clone());
361 ty
362 }
363 Pattern::Cons { cons, flds, span } => {
364 let cons_scm = &self.cons_ctx[&cons.ident];
366
367 let inst_map: HashMap<Ident, TermType> = cons_scm
368 .polys
369 .iter()
370 .map(|poly| (*poly, Term::Var(poly.uniquify())))
371 .collect();
372
373 let inst_flds: Vec<_> = cons_scm
374 .flds
375 .iter()
376 .map(|fld| fld.substitute(&inst_map))
377 .collect();
378
379 let inst_res = cons_scm.res.substitute(&inst_map);
380
381 let flds: Vec<_> = flds
382 .iter()
383 .map(|fld| (self.check_patn(fld), fld.get_span()))
384 .collect();
385
386 self.unify_many(&inst_flds, &flds, span);
387 inst_res
388 }
389 Pattern::Tuple { flds, span: _ } => {
390 let typs: Vec<TermType> = flds.iter().map(|fld| self.check_patn(fld)).collect();
391 TermType::Cons(OptCons::None, typs)
392 }
393 }
394 }
395
396 fn check_type(&mut self, typ: &Type) -> TermType {
397 match typ {
398 Type::Lit { lit, span: _ } => Term::Lit(*lit),
399 Type::Var { var, span: _ } => Term::Var(var.ident),
400 Type::Cons {
401 cons,
402 flds,
403 span: _,
404 } => {
405 let flds: Vec<_> = flds.iter().map(|fld| self.check_type(fld)).collect();
406 let data_scm = &self.data_ctx[&cons.ident];
407 if flds.len() != data_scm.polys.len() {
408 self.errors.push(CheckError::TypeArityMismatch {
409 actual: flds.len(),
410 expected: data_scm.polys.len(),
411 span: typ.get_span(),
412 });
413 }
414 Term::Cons(OptCons::Some(cons.ident), flds)
415 }
416 Type::Tuple { flds, span: _ } => {
417 let flds: Vec<TermType> = flds.iter().map(|fld| self.check_type(fld)).collect();
418 Term::Cons(OptCons::None, flds)
419 }
420 }
421 }
422
423 fn scan_data_ty_scm(&mut self, data_decl: &DataDecl) {
424 for poly in &data_decl.polys {
425 self.unifier.fresh(poly.ident);
426 }
427 let data_scm = DataTyScm {
428 polys: data_decl.polys.iter().map(|poly| poly.ident).collect(),
429 };
430 self.data_ctx.insert(data_decl.name.ident, data_scm);
431 }
432
433 fn scan_cons_ty_scm(&mut self, data_decl: &DataDecl) {
434 let res = TermType::Cons(
435 OptCons::Some(data_decl.name.ident),
436 data_decl
437 .polys
438 .iter()
439 .map(|poly| TermType::Var(poly.ident))
440 .collect(),
441 );
442
443 for cons in &data_decl.cons {
444 let flds = cons.flds.iter().map(|fld| self.check_type(fld)).collect();
445 let cons_typ = ConsTyScm {
446 polys: data_decl.polys.iter().map(|poly| poly.ident).collect(),
447 flds,
448 res: res.clone(),
449 };
450 self.cons_ctx.insert(cons.name.ident, cons_typ);
451 }
452 }
453
454 fn scan_func_ty_scm(&mut self, func_decl: &FuncDecl) {
455 for poly in &func_decl.polys {
456 self.unifier.fresh(poly.ident);
457 }
458
459 let polys = func_decl.polys.iter().map(|poly| poly.ident).collect();
460 let pars = func_decl
461 .pars
462 .iter()
463 .map(|(_par, typ)| self.check_type(typ))
464 .collect();
465
466 let res = self.check_type(&func_decl.res);
467 let func_scm = FuncTyScm { polys, pars, res };
468 self.func_ctx.insert(func_decl.name.ident, func_scm);
469 }
470
471 fn check_func_decl(&mut self, func_decl: &FuncDecl) {
472 let func_scm = self.func_ctx[&func_decl.name.ident].clone();
473 for ((par, _), par_ty) in func_decl.pars.iter().zip(func_scm.pars.iter()) {
474 self.val_ctx.insert(par.ident, par_ty.clone());
475 }
476 let body_ty = self.infer_expr(&func_decl.body);
477 let body_span = func_decl.body.get_span();
478 self.unify(&func_scm.res, &body_ty, &body_span);
479 }
480
481 fn check_prog(&mut self, prog: &Program) {
482 for data_decl in &prog.datas {
483 self.scan_data_ty_scm(data_decl);
484 }
485
486 for data_decl in &prog.datas {
487 self.scan_cons_ty_scm(data_decl);
488 }
489
490 for func_decl in &prog.funcs {
491 self.scan_func_ty_scm(func_decl);
492 }
493
494 for func_decl in &prog.funcs {
495 self.check_func_decl(func_decl);
496 }
497 }
498}
499
500pub fn check_pass(prog: &Program) -> Vec<CheckError> {
501 let mut pass = Checker::new();
502 pass.check_prog(prog);
503 let mut errors = std::mem::take(&mut pass.errors);
504 for err in &mut errors {
505 match err {
506 CheckError::UnifyFailed {
507 typ1,
508 typ2,
509 span: _,
510 } => {
511 *typ1 = pass.unifier.subst(typ1);
512 *typ2 = pass.unifier.subst(typ2);
513 }
514 CheckError::OccurCheckFailed {
515 var: _,
516 typ,
517 span: _,
518 } => {
519 *typ = pass.unifier.subst(typ);
520 }
521 CheckError::UnifyVecDiffLen {
522 vec1,
523 vec2,
524 span: _,
525 } => {
526 *vec1 = vec1.iter().map(|t| pass.unifier.subst(t)).collect();
527 *vec2 = vec2.iter().map(|t| pass.unifier.subst(t)).collect();
528 }
529 CheckError::TypeArityMismatch {
530 actual: _,
531 expected: _,
532 span: _,
533 } => {
534 }
536 }
537 }
538 errors
539}
540
/// Manual smoke test: parses, renames and type-checks a small program,
/// printing any type-check diagnostics.
#[test]
#[ignore = "just to see result"]
fn check_test() {
    let src: &'static str = r#"
datatype List[a] where
| Cons(a, List[a])
| Nil
end

function append[a](xs: List[a], x: a) -> List[a]
begin
    match xs with
    | Cons(head, tail) => Cons(head, append(tail, x))
    | Nil => Cons(x, Nil)
    end
end

function is_elem(xs: List[Int], x: Int) -> Bool
begin
    match xs with
    | Cons(head, tail) => if head == x then true else is_elem(tail, x)
    | Nil => false
    end
end

function is_elem_after_append(xs: List[Int], x: Int)
begin
    guard is_elem(append(xs, x), x) = false;
end

query is_elem_after_append(depth_step=5, depth_limit=50, answer_limit=1)
"#;
    let (mut prog, errs) = crate::syntax::parser::parse_program(src);
    assert!(errs.is_empty());

    let errs = crate::tych::rename::rename_pass(&mut prog);
    assert!(errs.is_empty());

    let errs = check_pass(&prog);

    // Print every diagnostic BEFORE asserting. The assertion used to come
    // first, which made this reporting loop unreachable whenever there was
    // anything to report.
    for err in &errs {
        let diag: Diagnostic = err.clone().into();
        println!("{}", diag.report(src, 10));
    }
    assert!(errs.is_empty());
}