prism_compiler/lang/
type_check.rs

1use crate::lang::env::Env;
2use crate::lang::env::EnvEntry::*;
3use crate::lang::error::{AggregatedTypeError, TypeError};
4use crate::lang::UnionIndex;
5use crate::lang::ValueOrigin;
6use crate::lang::{PartialExpr, TcEnv};
7use std::mem;
8
9impl TcEnv {
10    pub fn type_check(&mut self, root: UnionIndex) -> Result<UnionIndex, AggregatedTypeError> {
11        let ti = self._type_check(root, &Env::new());
12
13        let errors = mem::take(&mut self.errors);
14        if errors.is_empty() {
15            Ok(ti)
16        } else {
17            Err(AggregatedTypeError { errors })
18        }
19    }
20
21    /// Type checkes `i` in scope `s`. Returns the type.
22    /// Invariant: Returned UnionIndex is valid in Env `s`
23    fn _type_check(&mut self, i: UnionIndex, s: &Env) -> UnionIndex {
24        // We should only type check values from the source code
25        assert!(matches!(self.value_origins[*i], ValueOrigin::SourceCode(_)));
26
27        let t = match self.values[*i] {
28            PartialExpr::Type => PartialExpr::Type,
29            PartialExpr::Let(mut v, b) => {
30                // Check `v`
31                let err_count = self.errors.len();
32                let vt = self._type_check(v, s);
33                if self.errors.len() > err_count {
34                    v = self.store(PartialExpr::Free, ValueOrigin::Failure);
35                }
36
37                let bt = self._type_check(b, &s.cons(CSubst(v, vt)));
38                PartialExpr::Let(v, bt)
39            }
40            PartialExpr::DeBruijnIndex(index) => match s.get(index) {
41                Some(&CType(_, t) | &CSubst(_, t)) => PartialExpr::Shift(t, index + 1),
42                Some(_) => unreachable!(),
43                None => {
44                    self.errors.push(TypeError::IndexOutOfBound(i));
45                    return self.store(PartialExpr::Free, ValueOrigin::Failure);
46                }
47            },
48            PartialExpr::FnType(mut a, b) => {
49                let err_count = self.errors.len();
50                let at = self._type_check(a, s);
51                self.expect_beq_type(at, s);
52                if self.errors.len() > err_count {
53                    a = self.store(PartialExpr::Free, ValueOrigin::Failure);
54                }
55
56                let err_count = self.errors.len();
57                let bs = s.cons(CType(self.new_tc_id(), a));
58                let bt = self._type_check(b, &bs);
59
60                // Check if `b` typechecked without errors.
61                if self.errors.len() == err_count {
62                    self.expect_beq_type(bt, &bs);
63                }
64
65                PartialExpr::Type
66            }
67            PartialExpr::FnConstruct(b) => {
68                let a = self.store(PartialExpr::Free, ValueOrigin::FreeSub(i));
69                let bs = s.cons(CType(self.new_tc_id(), a));
70                let bt = self._type_check(b, &bs);
71                PartialExpr::FnType(a, bt)
72            }
73            PartialExpr::FnDestruct(f, mut a) => {
74                let err_count = self.errors.len();
75                let at = self._type_check(a, s);
76                if self.errors.len() > err_count {
77                    a = self.store(PartialExpr::Free, ValueOrigin::Failure);
78                };
79
80                let rt = self.store(PartialExpr::Free, ValueOrigin::TypeOf(i));
81
82                let err_count = self.errors.len();
83                let ft = self._type_check(f, s);
84                if self.errors.len() == err_count {
85                    self.expect_beq_fn_type(ft, at, rt, s)
86                }
87
88                PartialExpr::Let(a, rt)
89            }
90            PartialExpr::Free => {
91                // TODO self.queued_tc.insert(i, (s.clone(), t));
92                PartialExpr::Free
93            }
94            PartialExpr::Shift(v, shift) => {
95                PartialExpr::Shift(self._type_check(v, &s.shift(shift)), shift)
96            }
97            PartialExpr::TypeAssert(e, typ) => {
98                let err_count1 = self.errors.len();
99                let et = self._type_check(e, s);
100
101                let err_count2 = self.errors.len();
102                let typt = self._type_check(typ, s);
103                if self.errors.len() == err_count2 {
104                    self.expect_beq_type(typt, s);
105                }
106
107                if self.errors.len() == err_count1 {
108                    self.expect_beq_assert(e, et, typ, s);
109                }
110
111                return et;
112            }
113        };
114        let tid = self.store(t, ValueOrigin::TypeOf(i));
115        self.value_types.insert(i, tid);
116        tid
117    }
118}