prism_compiler/lang/
type_check.rs1use crate::lang::env::Env;
2use crate::lang::env::EnvEntry::*;
3use crate::lang::error::{AggregatedTypeError, TypeError};
4use crate::lang::UnionIndex;
5use crate::lang::ValueOrigin;
6use crate::lang::{PartialExpr, TcEnv};
7use std::mem;
8
9impl TcEnv {
10 pub fn type_check(&mut self, root: UnionIndex) -> Result<UnionIndex, AggregatedTypeError> {
11 let ti = self._type_check(root, &Env::new());
12
13 let errors = mem::take(&mut self.errors);
14 if errors.is_empty() {
15 Ok(ti)
16 } else {
17 Err(AggregatedTypeError { errors })
18 }
19 }
20
21 fn _type_check(&mut self, i: UnionIndex, s: &Env) -> UnionIndex {
24 assert!(matches!(self.value_origins[*i], ValueOrigin::SourceCode(_)));
26
27 let t = match self.values[*i] {
28 PartialExpr::Type => PartialExpr::Type,
29 PartialExpr::Let(mut v, b) => {
30 let err_count = self.errors.len();
32 let vt = self._type_check(v, s);
33 if self.errors.len() > err_count {
34 v = self.store(PartialExpr::Free, ValueOrigin::Failure);
35 }
36
37 let bt = self._type_check(b, &s.cons(CSubst(v, vt)));
38 PartialExpr::Let(v, bt)
39 }
40 PartialExpr::DeBruijnIndex(index) => match s.get(index) {
41 Some(&CType(_, t) | &CSubst(_, t)) => PartialExpr::Shift(t, index + 1),
42 Some(_) => unreachable!(),
43 None => {
44 self.errors.push(TypeError::IndexOutOfBound(i));
45 return self.store(PartialExpr::Free, ValueOrigin::Failure);
46 }
47 },
48 PartialExpr::FnType(mut a, b) => {
49 let err_count = self.errors.len();
50 let at = self._type_check(a, s);
51 self.expect_beq_type(at, s);
52 if self.errors.len() > err_count {
53 a = self.store(PartialExpr::Free, ValueOrigin::Failure);
54 }
55
56 let err_count = self.errors.len();
57 let bs = s.cons(CType(self.new_tc_id(), a));
58 let bt = self._type_check(b, &bs);
59
60 if self.errors.len() == err_count {
62 self.expect_beq_type(bt, &bs);
63 }
64
65 PartialExpr::Type
66 }
67 PartialExpr::FnConstruct(b) => {
68 let a = self.store(PartialExpr::Free, ValueOrigin::FreeSub(i));
69 let bs = s.cons(CType(self.new_tc_id(), a));
70 let bt = self._type_check(b, &bs);
71 PartialExpr::FnType(a, bt)
72 }
73 PartialExpr::FnDestruct(f, mut a) => {
74 let err_count = self.errors.len();
75 let at = self._type_check(a, s);
76 if self.errors.len() > err_count {
77 a = self.store(PartialExpr::Free, ValueOrigin::Failure);
78 };
79
80 let rt = self.store(PartialExpr::Free, ValueOrigin::TypeOf(i));
81
82 let err_count = self.errors.len();
83 let ft = self._type_check(f, s);
84 if self.errors.len() == err_count {
85 self.expect_beq_fn_type(ft, at, rt, s)
86 }
87
88 PartialExpr::Let(a, rt)
89 }
90 PartialExpr::Free => {
91 PartialExpr::Free
93 }
94 PartialExpr::Shift(v, shift) => {
95 PartialExpr::Shift(self._type_check(v, &s.shift(shift)), shift)
96 }
97 PartialExpr::TypeAssert(e, typ) => {
98 let err_count1 = self.errors.len();
99 let et = self._type_check(e, s);
100
101 let err_count2 = self.errors.len();
102 let typt = self._type_check(typ, s);
103 if self.errors.len() == err_count2 {
104 self.expect_beq_type(typt, s);
105 }
106
107 if self.errors.len() == err_count1 {
108 self.expect_beq_assert(e, et, typ, s);
109 }
110
111 return et;
112 }
113 };
114 let tid = self.store(t, ValueOrigin::TypeOf(i));
115 self.value_types.insert(i, tid);
116 tid
117 }
118}