Skip to main content

prune_lang/tych/
check.rs

1use super::*;
2
3use crate::utils::prim::Prim;
4use crate::utils::unify::*;
5
6#[derive(Clone, Debug)]
7struct FuncTyScm {
8    polys: Vec<Ident>,
9    pars: Vec<TermType>,
10    res: TermType,
11}
12
13#[derive(Clone, Debug)]
14struct ConsTyScm {
15    polys: Vec<Ident>,
16    flds: Vec<TermType>,
17    res: TermType,
18}
19
20#[allow(unused)]
21#[derive(Clone, Debug)]
22struct DataTyScm {
23    polys: Vec<Ident>,
24}
25
26#[derive(Clone, Debug, PartialEq)]
27pub enum CheckError {
28    UnifyFailed {
29        typ1: TermType,
30        typ2: TermType,
31        span: Span,
32    },
33    OccurCheckFailed {
34        var: Ident,
35        typ: TermType,
36        span: Span,
37    },
38    UnifyVecDiffLen {
39        vec1: Vec<TermType>,
40        vec2: Vec<TermType>,
41        span: Span,
42    },
43    TypeArityMismatch {
44        actual: usize,
45        expected: usize,
46        span: Span,
47    },
48}
49
50use crate::cli::diagnostic::Diagnostic;
51impl From<CheckError> for Diagnostic {
52    fn from(val: CheckError) -> Self {
53        match val {
54            CheckError::UnifyFailed {typ1, typ2, span } => {
55                Diagnostic::error("cannot match type!".to_string()).line_span(
56                    span.clone(),
57                    format!("the expression here has type {typ1}, but expected {typ2}."),
58                )
59            }
60            CheckError::OccurCheckFailed { var, typ, span } => {
61                Diagnostic::error("occurrence check failed!".to_string()).line_span(
62                    span.clone(),
63                    format!("failed to unify the variable {var} with type {typ}, since it occurs in its own type."),
64                )
65            }
66            CheckError::UnifyVecDiffLen { vec1, vec2, span } => {
67                Diagnostic::error("type vectors have different length!".to_string()).line_span(
68                    span.clone(),
69                    format!("failed to unify two vectors with lengths: {vec1:?} and {vec2:?}"),
70                )
71            }
72            CheckError::TypeArityMismatch { actual, expected, span } => {
73                Diagnostic::error("type arity mismatch!".to_string()).line_span(
74                    span.clone(),
75                    format!("the type constructor has arity {actual}, but expected arity {expected}."),
76                )
77            }
78        }
79    }
80}
81
82struct Checker {
83    val_ctx: HashMap<Ident, TermType>,
84    func_ctx: HashMap<Ident, FuncTyScm>,
85    cons_ctx: HashMap<Ident, ConsTyScm>,
86    data_ctx: HashMap<Ident, DataTyScm>,
87    unifier: Unifier<Ident, LitType, OptCons<Ident>>,
88    errors: Vec<CheckError>,
89}
90
91impl Checker {
92    pub fn new() -> Checker {
93        Checker {
94            val_ctx: HashMap::new(),
95            func_ctx: HashMap::new(),
96            cons_ctx: HashMap::new(),
97            data_ctx: HashMap::new(),
98            unifier: Unifier::new(),
99            errors: Vec::new(),
100        }
101    }
102
103    fn fresh(&mut self) -> TermType {
104        TermType::Var(Ident::fresh(&"a"))
105    }
106
107    fn unify(&mut self, typ1: &TermType, typ2: &TermType, span: &Span) {
108        match self.unifier.unify(typ1, typ2) {
109            Ok(()) => {}
110            Err(UnifyError::UnifyFailed(typ1, typ2)) => {
111                self.errors.push(CheckError::UnifyFailed {
112                    typ1,
113                    typ2,
114                    span: span.clone(),
115                });
116            }
117            Err(UnifyError::OccurCheckFailed(var, typ)) => {
118                self.errors.push(CheckError::OccurCheckFailed {
119                    var,
120                    typ,
121                    span: span.clone(),
122                });
123            }
124            Err(UnifyError::UnifyVecDiffLen(vec1, vec2)) => {
125                self.errors.push(CheckError::UnifyVecDiffLen {
126                    vec1,
127                    vec2,
128                    span: span.clone(),
129                });
130            }
131        }
132    }
133
134    fn unify_many(&mut self, vec1: &[TermType], vec2: &[(TermType, Span)], span: &Span) {
135        if vec1.len() == vec2.len() {
136            for (lhs, (rhs, span)) in vec1.iter().zip(vec2.iter()) {
137                self.unify(lhs, rhs, span);
138            }
139        } else {
140            self.errors.push(CheckError::UnifyVecDiffLen {
141                vec1: vec1.to_vec(),
142                vec2: vec2.iter().map(|x| x.0.clone()).collect(),
143                span: span.clone(),
144            });
145        }
146    }
147
148    fn check_prim(&mut self, prim: &Prim, args: &[Expr], span: &Span) -> TermType {
149        let args: Vec<_> = args
150            .iter()
151            .map(|arg| (self.infer_expr(arg), arg.get_span()))
152            .collect();
153
154        match prim {
155            Prim::IAdd | Prim::ISub | Prim::IMul | Prim::IDiv | Prim::IRem => {
156                self.unify_many(
157                    &[TermType::Lit(LitType::TyInt), TermType::Lit(LitType::TyInt)],
158                    &args,
159                    span,
160                );
161                TermType::Lit(LitType::TyInt)
162            }
163            Prim::INeg => {
164                self.unify_many(&[TermType::Lit(LitType::TyInt)], &args, span);
165                TermType::Lit(LitType::TyInt)
166            }
167            Prim::ICmp(_) => {
168                self.unify_many(
169                    &[TermType::Lit(LitType::TyInt), TermType::Lit(LitType::TyInt)],
170                    &args,
171                    span,
172                );
173                TermType::Lit(LitType::TyBool)
174            }
175            Prim::BAnd | Prim::BOr => {
176                self.unify_many(
177                    &[
178                        TermType::Lit(LitType::TyBool),
179                        TermType::Lit(LitType::TyBool),
180                    ],
181                    &args,
182                    span,
183                );
184                TermType::Lit(LitType::TyBool)
185            }
186            Prim::BNot => {
187                self.unify_many(&[TermType::Lit(LitType::TyBool)], &args, span);
188                TermType::Lit(LitType::TyBool)
189            }
190        }
191    }
192
193    fn infer_expr(&mut self, expr: &Expr) -> TermType {
194        match expr {
195            Expr::Lit { lit, span: _ } => TermType::Lit(lit.get_typ()),
196            Expr::Var { var, span: _ } => self.val_ctx[&var.ident].clone(),
197            Expr::Prim { prim, args, span } => self.check_prim(prim, args, span),
198            Expr::Cons { cons, flds, span } => {
199                // instantiate constructor type scheme
200                let cons_scm = &self.cons_ctx[&cons.ident];
201
202                let inst_map: HashMap<Ident, TermType> = cons_scm
203                    .polys
204                    .iter()
205                    .map(|poly| (*poly, Term::Var(poly.uniquify())))
206                    .collect();
207
208                let inst_flds: Vec<_> = cons_scm
209                    .flds
210                    .iter()
211                    .map(|fld| fld.substitute(&inst_map))
212                    .collect();
213
214                let inst_res = cons_scm.res.substitute(&inst_map);
215
216                let flds: Vec<_> = flds
217                    .iter()
218                    .map(|fld| (self.infer_expr(fld), fld.get_span()))
219                    .collect();
220
221                self.unify_many(&inst_flds, &flds, span);
222                inst_res
223            }
224            Expr::Tuple { flds, span: _ } => {
225                let flds: Vec<TermType> = flds.iter().map(|fld| self.infer_expr(fld)).collect();
226                TermType::Cons(OptCons::None, flds)
227            }
228            Expr::Match {
229                expr,
230                brchs,
231                span: _,
232            } => {
233                let expr_ty = self.infer_expr(expr);
234                let res = self.fresh();
235                for (patn, cont) in brchs {
236                    let patn_ty = self.check_patn(patn);
237                    let patn_span = patn.get_span();
238                    self.unify(&patn_ty, &expr_ty, &patn_span);
239                    let cont_ty = self.infer_expr(cont);
240                    let cont_span = cont.get_span();
241                    self.unify(&res, &cont_ty, &cont_span);
242                }
243                res
244            }
245            Expr::Let {
246                patn,
247                expr,
248                cont,
249                span: _,
250            } => {
251                let expr_ty = self.infer_expr(expr);
252                let expr_span = expr.get_span();
253                let patn_ty = self.check_patn(patn);
254                self.unify(&patn_ty, &expr_ty, &expr_span);
255                self.infer_expr(cont)
256            }
257            Expr::App { func, args, span } => {
258                // instantiate predicate type scheme
259                let func_scm = &self.func_ctx[&func.ident];
260
261                let inst_map: HashMap<Ident, TermType> = func_scm
262                    .polys
263                    .iter()
264                    .map(|poly| (*poly, Term::Var(poly.uniquify())))
265                    .collect();
266
267                let inst_pars: Vec<_> = func_scm
268                    .pars
269                    .iter()
270                    .map(|par| par.substitute(&inst_map))
271                    .collect();
272
273                let inst_res = func_scm.res.substitute(&inst_map);
274
275                let args: Vec<_> = args
276                    .iter()
277                    .map(|arg| (self.infer_expr(arg), arg.get_span()))
278                    .collect();
279
280                self.unify_many(&inst_pars, &args, span);
281                inst_res
282            }
283            Expr::Ifte {
284                cond,
285                then,
286                els,
287                span: _,
288            } => {
289                let cond_ty = self.infer_expr(cond);
290                let cond_span = cond.get_span();
291                self.unify(&cond_ty, &TermType::Lit(LitType::TyBool), &cond_span);
292                let then_ty = self.infer_expr(then);
293                let els_ty = self.infer_expr(els);
294                let els_span = els.get_span();
295                self.unify(&then_ty, &els_ty, &els_span);
296                then_ty
297            }
298            Expr::Cond { brchs, span: _ } => {
299                let res = self.fresh();
300                for (cond, body) in brchs {
301                    let cond_ty = self.infer_expr(cond);
302                    let cond_span = cond.get_span();
303                    let body_ty = self.infer_expr(body);
304                    let body_span = body.get_span();
305                    self.unify(&cond_ty, &TermType::Lit(LitType::TyBool), &cond_span);
306                    self.unify(&body_ty, &res, &body_span);
307                }
308                res
309            }
310            Expr::Alter { brchs, span: _ } => {
311                let res = self.fresh();
312                for body in brchs {
313                    let body_ty = self.infer_expr(body);
314                    let body_span = body.get_span();
315                    self.unify(&body_ty, &res, &body_span);
316                }
317                res
318            }
319            Expr::Fresh {
320                vars,
321                cont,
322                span: _,
323            } => {
324                for var in vars {
325                    let cell = self.fresh();
326                    self.val_ctx.insert(var.ident, cell);
327                }
328                self.infer_expr(cont)
329            }
330            Expr::Guard {
331                lhs,
332                rhs,
333                cont,
334                span: _,
335            } => {
336                let lhs_ty = self.infer_expr(lhs);
337                if let Some(rhs) = rhs {
338                    let rhs_ty = self.infer_expr(rhs);
339                    let rhs_span = rhs.get_span();
340                    self.unify(&lhs_ty, &rhs_ty, &rhs_span);
341                } else {
342                    let lhs_span = lhs.get_span();
343                    self.unify(
344                        &lhs_ty,
345                        &TermType::Cons(OptCons::None, Vec::new()),
346                        &lhs_span,
347                    );
348                }
349                self.infer_expr(cont)
350            }
351            Expr::Undefined { span: _ } => self.fresh(),
352        }
353    }
354
355    fn check_patn(&mut self, patn: &Pattern) -> TermType {
356        match patn {
357            Pattern::Lit { lit, span: _ } => TermType::Lit(lit.get_typ()),
358            Pattern::Var { var, span: _ } => {
359                let ty = self.fresh();
360                self.val_ctx.insert(var.ident, ty.clone());
361                ty
362            }
363            Pattern::Cons { cons, flds, span } => {
364                // instantiate constructor type scheme
365                let cons_scm = &self.cons_ctx[&cons.ident];
366
367                let inst_map: HashMap<Ident, TermType> = cons_scm
368                    .polys
369                    .iter()
370                    .map(|poly| (*poly, Term::Var(poly.uniquify())))
371                    .collect();
372
373                let inst_flds: Vec<_> = cons_scm
374                    .flds
375                    .iter()
376                    .map(|fld| fld.substitute(&inst_map))
377                    .collect();
378
379                let inst_res = cons_scm.res.substitute(&inst_map);
380
381                let flds: Vec<_> = flds
382                    .iter()
383                    .map(|fld| (self.check_patn(fld), fld.get_span()))
384                    .collect();
385
386                self.unify_many(&inst_flds, &flds, span);
387                inst_res
388            }
389            Pattern::Tuple { flds, span: _ } => {
390                let typs: Vec<TermType> = flds.iter().map(|fld| self.check_patn(fld)).collect();
391                TermType::Cons(OptCons::None, typs)
392            }
393        }
394    }
395
396    fn check_type(&mut self, typ: &Type) -> TermType {
397        match typ {
398            Type::Lit { lit, span: _ } => Term::Lit(*lit),
399            Type::Var { var, span: _ } => Term::Var(var.ident),
400            Type::Cons {
401                cons,
402                flds,
403                span: _,
404            } => {
405                let flds: Vec<_> = flds.iter().map(|fld| self.check_type(fld)).collect();
406                let data_scm = &self.data_ctx[&cons.ident];
407                if flds.len() != data_scm.polys.len() {
408                    self.errors.push(CheckError::TypeArityMismatch {
409                        actual: flds.len(),
410                        expected: data_scm.polys.len(),
411                        span: typ.get_span(),
412                    });
413                }
414                Term::Cons(OptCons::Some(cons.ident), flds)
415            }
416            Type::Tuple { flds, span: _ } => {
417                let flds: Vec<TermType> = flds.iter().map(|fld| self.check_type(fld)).collect();
418                Term::Cons(OptCons::None, flds)
419            }
420        }
421    }
422
423    fn scan_data_ty_scm(&mut self, data_decl: &DataDecl) {
424        for poly in &data_decl.polys {
425            self.unifier.fresh(poly.ident);
426        }
427        let data_scm = DataTyScm {
428            polys: data_decl.polys.iter().map(|poly| poly.ident).collect(),
429        };
430        self.data_ctx.insert(data_decl.name.ident, data_scm);
431    }
432
433    fn scan_cons_ty_scm(&mut self, data_decl: &DataDecl) {
434        let res = TermType::Cons(
435            OptCons::Some(data_decl.name.ident),
436            data_decl
437                .polys
438                .iter()
439                .map(|poly| TermType::Var(poly.ident))
440                .collect(),
441        );
442
443        for cons in &data_decl.cons {
444            let flds = cons.flds.iter().map(|fld| self.check_type(fld)).collect();
445            let cons_typ = ConsTyScm {
446                polys: data_decl.polys.iter().map(|poly| poly.ident).collect(),
447                flds,
448                res: res.clone(),
449            };
450            self.cons_ctx.insert(cons.name.ident, cons_typ);
451        }
452    }
453
454    fn scan_func_ty_scm(&mut self, func_decl: &FuncDecl) {
455        for poly in &func_decl.polys {
456            self.unifier.fresh(poly.ident);
457        }
458
459        let polys = func_decl.polys.iter().map(|poly| poly.ident).collect();
460        let pars = func_decl
461            .pars
462            .iter()
463            .map(|(_par, typ)| self.check_type(typ))
464            .collect();
465
466        let res = self.check_type(&func_decl.res);
467        let func_scm = FuncTyScm { polys, pars, res };
468        self.func_ctx.insert(func_decl.name.ident, func_scm);
469    }
470
471    fn check_func_decl(&mut self, func_decl: &FuncDecl) {
472        let func_scm = self.func_ctx[&func_decl.name.ident].clone();
473        for ((par, _), par_ty) in func_decl.pars.iter().zip(func_scm.pars.iter()) {
474            self.val_ctx.insert(par.ident, par_ty.clone());
475        }
476        let body_ty = self.infer_expr(&func_decl.body);
477        let body_span = func_decl.body.get_span();
478        self.unify(&func_scm.res, &body_ty, &body_span);
479    }
480
481    fn check_prog(&mut self, prog: &Program) {
482        for data_decl in &prog.datas {
483            self.scan_data_ty_scm(data_decl);
484        }
485
486        for data_decl in &prog.datas {
487            self.scan_cons_ty_scm(data_decl);
488        }
489
490        for func_decl in &prog.funcs {
491            self.scan_func_ty_scm(func_decl);
492        }
493
494        for func_decl in &prog.funcs {
495            self.check_func_decl(func_decl);
496        }
497    }
498}
499
500pub fn check_pass(prog: &Program) -> Vec<CheckError> {
501    let mut pass = Checker::new();
502    pass.check_prog(prog);
503    let mut errors = std::mem::take(&mut pass.errors);
504    for err in &mut errors {
505        match err {
506            CheckError::UnifyFailed {
507                typ1,
508                typ2,
509                span: _,
510            } => {
511                *typ1 = pass.unifier.subst(typ1);
512                *typ2 = pass.unifier.subst(typ2);
513            }
514            CheckError::OccurCheckFailed {
515                var: _,
516                typ,
517                span: _,
518            } => {
519                *typ = pass.unifier.subst(typ);
520            }
521            CheckError::UnifyVecDiffLen {
522                vec1,
523                vec2,
524                span: _,
525            } => {
526                *vec1 = vec1.iter().map(|t| pass.unifier.subst(t)).collect();
527                *vec2 = vec2.iter().map(|t| pass.unifier.subst(t)).collect();
528            }
529            CheckError::TypeArityMismatch {
530                actual: _,
531                expected: _,
532                span: _,
533            } => {
534                // do nothing
535            }
536        }
537    }
538    errors
539}
540
/// Manual smoke test: parses, renames and type checks a small program,
/// printing any type errors as rendered diagnostics.
#[test]
#[ignore = "just to see result"]
fn check_test() {
    let src: &'static str = r#"
datatype List[a] where
| Cons(a, List[a])
| Nil
end

function append[a](xs: List[a], x: a) -> List[a]
begin
    match xs with
    | Cons(head, tail) => Cons(head, append(tail, x))
    | Nil => Cons(x, Nil)
    end
end

function is_elem(xs: List[Int], x: Int) -> Bool
begin
    match xs with
    | Cons(head, tail) => if head == x then true else is_elem(tail, x) 
    | Nil => false
    end
end

function is_elem_after_append(xs: List[Int], x: Int)
begin
    guard is_elem(append(xs, x), x) = false;
end

query is_elem_after_append(depth_step=5, depth_limit=50, answer_limit=1)
"#;
    let (mut prog, errs) = crate::syntax::parser::parse_program(src);
    assert!(errs.is_empty());

    let errs = crate::tych::rename::rename_pass(&mut prog);
    assert!(errs.is_empty());

    let errs = check_pass(&prog);

    // Render the diagnostics *before* asserting. Otherwise a failing run would
    // stop at the assertion without showing which type errors were found.
    for err in errs.iter().cloned() {
        let diag: Diagnostic = err.into();
        println!("{}", diag.report(src, 10));
    }

    assert!(errs.is_empty(), "{} type error(s) reported", errs.len());
}