1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
use crate::lang::env::Env;
use crate::lang::env::EnvEntry::*;
use crate::lang::error::{AggregatedTypeError, TypeError};
use crate::lang::{ValueOrigin};
use crate::lang::UnionIndex;
use crate::lang::{PartialExpr, TcEnv};
use std::mem;

impl TcEnv {
    /// Type checks the expression rooted at `root`, starting from an empty environment.
    ///
    /// On success, returns the index of the stored type of `root`.
    ///
    /// # Errors
    ///
    /// If any type errors were recorded while checking, all of them are returned together as an
    /// [`AggregatedTypeError`]. The internal error buffer is drained in either case.
    pub fn type_check(&mut self, root: UnionIndex) -> Result<UnionIndex, AggregatedTypeError> {
        let root_type = self._type_check(root, &Env::new());

        let errors = mem::take(&mut self.errors);
        if !errors.is_empty() {
            return Err(AggregatedTypeError { errors });
        }
        Ok(root_type)
    }

    /// Type checks the value stored at `i` within environment `s`.
    ///
    /// The computed type is stored as a new value with origin `ValueOrigin::TypeOf(i)`. It is
    /// recorded in `value_types` under `i`, and its index is returned.
    ///
    /// Type errors are not returned; they are pushed onto `self.errors`. When checking a bound
    /// value or a function argument/parameter records new errors, that sub-term is replaced by a
    /// `Free` placeholder (origin `ValueOrigin::Failure`) so the rest of the term can still be
    /// checked.
    ///
    /// # Panics
    ///
    /// - In debug builds, if the value at `i` does not originate from source code.
    /// - Via `unreachable!`, on a `Shift` value.
    /// - Via `unreachable!`, on an environment entry that is neither `CType` nor `CSubst`.
    ///
    /// Invariant: Returned UnionIndex is valid in Env `s`
    pub(crate) fn _type_check(&mut self, i: UnionIndex, s: &Env) -> UnionIndex {
        // Only values coming from the source code are expected to reach the type checker.
        debug_assert!(matches!(self.value_origins[i.0], ValueOrigin::SourceCode(_)));

        let ty = match self.values[i.0] {
            PartialExpr::Type => PartialExpr::Type,

            // `let v in b` has type `let v in <type of b>`. `v` and its type are bound as a
            // `CSubst` entry while checking the body.
            PartialExpr::Let(v, b) => {
                let mark = self.errors.len();
                let v_ty = self._type_check(v, s);
                let bound = if self.errors.len() > mark {
                    self.store(PartialExpr::Free, ValueOrigin::Failure)
                } else {
                    v
                };

                let body_ty = self._type_check(b, &s.cons(CSubst(bound, v_ty)));
                PartialExpr::Let(bound, body_ty)
            }

            // A variable's type is the type held by its environment entry, wrapped in
            // `Shift(.., index + 1)`. An out-of-range index records an error and uses a
            // failure placeholder instead.
            PartialExpr::DeBruijnIndex(index) => {
                let entry_ty = match s.get(index) {
                    Some(&CType(_, t)) | Some(&CSubst(_, t)) => t,
                    None => {
                        self.errors.push(TypeError::IndexOutOfBound(i));
                        self.store(PartialExpr::Free, ValueOrigin::Failure)
                    }
                    _ => unreachable!(),
                };
                PartialExpr::Shift(entry_ty, index + 1)
            }

            // A function type has type `Type`. The domain's type is passed to
            // `expect_beq_type`. The codomain is checked with the domain bound as a `CType`
            // entry, and its type goes to `expect_beq_type` only if it checked cleanly.
            PartialExpr::FnType(a, b) => {
                let mark = self.errors.len();
                let a_ty = self._type_check(a, s);
                self.expect_beq_type(a_ty, s);
                let domain = if self.errors.len() > mark {
                    self.store(PartialExpr::Free, ValueOrigin::Failure)
                } else {
                    a
                };

                let mark = self.errors.len();
                let codomain_env = s.cons(CType(self.new_tc_id(), domain));
                let codomain_ty = self._type_check(b, &codomain_env);
                if self.errors.len() == mark {
                    self.expect_beq_type(codomain_ty, &codomain_env);
                }

                PartialExpr::Type
            }

            // A lambda with parameter type `a` and body `b` has type `FnType(a, <type of b>)`.
            PartialExpr::FnConstruct(a, b) => {
                let mark = self.errors.len();
                let a_ty = self._type_check(a, s);
                self.expect_beq_type(a_ty, s);
                let param = if self.errors.len() > mark {
                    self.store(PartialExpr::Free, ValueOrigin::Failure)
                } else {
                    a
                };

                let body_env = s.cons(CType(self.new_tc_id(), param));
                let body_ty = self._type_check(b, &body_env);
                PartialExpr::FnType(param, body_ty)
            }

            // Applying `f` to `a` yields `Let(a, rt)`, where `rt` is a fresh free type. When
            // `f` checks cleanly, `expect_beq_fn_type` relates `f`'s type to `a`'s type and `rt`.
            PartialExpr::FnDestruct(f, a) => {
                let mark = self.errors.len();
                let arg_ty = self._type_check(a, s);
                let arg = if self.errors.len() > mark {
                    self.store(PartialExpr::Free, ValueOrigin::Failure)
                } else {
                    a
                };

                let result_ty = self.store(PartialExpr::Free, ValueOrigin::TypeOf(i));

                let mark = self.errors.len();
                let fn_ty = self._type_check(f, s);
                if self.errors.len() == mark {
                    self.expect_beq_fn_type(fn_ty, arg_ty, result_ty, s);
                }

                PartialExpr::Let(arg, result_ty)
            }

            // A free value gets a fresh free type; the shared tail below stores and records it.
            PartialExpr::Free => {
                // TODO self.queued_tc.insert(i, (s.clone(), t));
                PartialExpr::Free
            }

            PartialExpr::Shift(..) => unreachable!(),
        };

        let type_index = self.store(ty, ValueOrigin::TypeOf(i));
        self.value_types.insert(i, type_index);
        type_index
    }
}