bitsy_lang/expr/
typecheck.rs

1use super::*;
2use crate::types::*;
3
4impl Expr {
5    #[allow(unused_variables)] // TODO remove this
6    pub fn typecheck(self: &Arc<Self>, type_expected: Type, ctx: Context<Path, Type>) -> Result<(), TypeError> {
7        if let Some(type_actual) = self.typeinfer(ctx.clone()) {
8            if type_actual == type_expected {
9                return Ok(());
10            } else {
11                return Err(TypeError::NotExpectedType(type_expected.clone(), type_actual.clone(), self.clone()));
12            }
13        }
14
15        let result = match (type_expected.clone(), &**self) {
16            (_type_expected, Expr::Reference(_loc, _typ, path)) => Err(TypeError::UndefinedReference(self.clone())),
17            (Type::Word(width_expected), Expr::Word(_loc, typ, width_actual, n)) => {
18                if let Some(width_actual) = width_actual {
19                    if *width_actual == width_expected {
20                        Err(TypeError::Other(self.clone(), format!("Not the expected width")))
21                    } else if n >> *width_actual != 0 {
22                        Err(TypeError::Other(self.clone(), format!("Doesn't fit")))
23                    } else {
24                        Ok(())
25                    }
26                } else {
27                    if n >> width_expected != 0 {
28                        Err(TypeError::Other(self.clone(), format!("Doesn't fit")))
29                    } else {
30                        Ok(())
31                    }
32                }
33            },
34            (_type_expected, Expr::Enum(_loc, _typ, typedef, _name)) => {
35                if type_expected == *typedef {
36                    Ok(())
37                } else {
38                    Err(TypeError::Other(self.clone(), format!("Type Error")))
39                }
40            },
41            (_type_expected, Expr::Ctor(loc, _typ, name, es)) => {
42                // TODO
43                if let Type::Valid(ref typ) = type_expected {
44                    if es.len() == 1 {
45                        es[0].typecheck(*typ.clone(), ctx.clone())
46                    } else if es.len() > 1 {
47                        Err(TypeError::Other(self.clone(), format!("Error")))
48                    } else {
49                        Ok(())
50                    }
51                } else {
52                    Err(TypeError::Other(self.clone(), format!("Not a Valid<T>: {self:?} is not {type_expected:?}")))
53                }
54            },
55            (Type::Struct(typedef), Expr::Struct(_loc, _typ, fields)) => {
56                // TODO Ensure all fields exist.
57                for (name, e) in fields {
58                    if let Some(typ) = typedef.type_of_field(name) {
59                        e.typecheck(typ, ctx.clone())?;
60                    } else {
61                        let typename = &typedef.name;
62                        return Err(TypeError::Other(self.clone(), format!("struct type {typename} has no field {name}")))
63                    }
64                }
65                Ok(())
66            },
67            (_type_expected, Expr::Let(_loc, _typ, name, e, b)) => {
68                if let Some(typ) = e.typeinfer(ctx.clone()) {
69                    b.typecheck(type_expected.clone(), ctx.extend(name.clone().into(), typ))
70                } else {
71                    Err(TypeError::Other(self.clone(), format!("Can infer type of {e:?} in let expression.")))
72                }
73            },
74            (_type_expected, Expr::Match(_loc, _typ, subject, arms)) => {
75                if let Some(subject_typ) = subject.typeinfer(ctx.clone()) {
76                    let invalid_arms: Vec<&MatchArm> = arms.into_iter().filter(|MatchArm(pat, _e)| !subject_typ.valid_pat(pat)).collect();
77                    if invalid_arms.len() > 0 {
78                        return Err(TypeError::Other(self.clone(), format!("Invalid patterns for {subject_typ:?}")));
79                    }
80                    // TODO check pattern linearity
81
82                    for MatchArm(pat, e) in arms {
83                        let new_ctx = subject_typ.extend_context_for_pat(ctx.clone(), pat);
84                        e.typecheck(type_expected.clone(), new_ctx)?;
85                    }
86                } else {
87                    return Err(TypeError::Other(self.clone(), format!("Match: Can't infer subject type")));
88                }
89                Ok(())
90            },
91            (_type_expected, Expr::UnOp(_loc, _typ, UnOp::Not, e)) => e.typecheck(type_expected.clone(), ctx.clone()),
92            (Type::Word(1), Expr::BinOp(_loc, _typ, BinOp::Eq | BinOp::Neq | BinOp::Lt, e1, e2)) => {
93                if let Some(typ1) = e1.typeinfer(ctx.clone()) {
94                    e2.typecheck(typ1, ctx.clone())?;
95                    Ok(())
96                } else {
97                    Err(TypeError::Other(self.clone(), format!("Can't infer type.")))
98                }
99            },
100            (Type::Word(n), Expr::BinOp(_loc, _typ, BinOp::Add | BinOp::Sub | BinOp::And | BinOp::Or | BinOp::Xor, e1, e2)) => {
101                e1.typecheck(type_expected.clone(), ctx.clone())?;
102                e2.typecheck(type_expected.clone(), ctx.clone())?;
103                Ok(())
104            },
105            (Type::Word(n), Expr::BinOp(_loc, _typ, BinOp::AddCarry, e1, e2)) => {
106                if let (Some(typ1), Some(typ2)) = (e1.typeinfer(ctx.clone()), e2.typeinfer(ctx.clone())) {
107                    if n > 0 && typ1 == typ2 && typ1 == Type::Word(n - 1) {
108                        Ok(())
109                    } else {
110                        Err(TypeError::Other(self.clone(), format!("Types don't match")))
111                    }
112                } else {
113                    Err(TypeError::Other(self.clone(), format!("Can't infer type.")))
114                }
115            },
116            (_type_expected, Expr::If(_loc, _typ, cond, e1, e2)) => {
117                cond.typecheck(Type::word(1), ctx.clone())?;
118                e1.typecheck(type_expected.clone(), ctx.clone())?;
119                e2.typecheck(type_expected.clone(), ctx.clone())?;
120                Ok(())
121            },
122            (_type_expected, Expr::Mux(_loc, _typ, cond, e1, e2)) => {
123                cond.typecheck(Type::word(1), ctx.clone())?;
124                e1.typecheck(type_expected.clone(), ctx.clone())?;
125                e2.typecheck(type_expected.clone(), ctx.clone())?;
126                Ok(())
127            },
128            (Type::Word(width_expected), Expr::Sext(_loc, typ, e)) => {
129                if let Some(type_actual) = e.typeinfer(ctx.clone()) {
130                    if let Type::Word(m) = type_actual {
131                        if width_expected >= m {
132                            Ok(())
133                        } else {
134                            Err(TypeError::Other(self.clone(), format!("Can't sext a Word<{m}> to a a Word<{width_expected}>")))
135                        }
136                    } else {
137                        Err(TypeError::Other(self.clone(), format!("Unknown?")))
138                    }
139                } else {
140                    Err(TypeError::CantInferType(self.clone()))
141                }
142            },
143            (Type::Word(n), Expr::ToWord(_loc, typ, e)) => {
144                let typ = e.typeinfer(ctx.clone()).unwrap();
145                if let Type::Enum(typedef) = typ {
146                    let width = typedef.bitwidth();
147                    if n == width {
148                        Ok(())
149                    } else {
150                        let name = &typedef.name;
151                        Err(TypeError::Other(self.clone(), format!("enum type {name} has bitwidth {width} which cannot be cast to Word<{n}>")))
152                    }
153                } else {
154                    unreachable!()
155                }
156            },
157            (Type::Vec(typ, n), Expr::Vec(_loc, _typ, es)) => {
158                for e in es {
159                    e.typecheck(*typ.clone(), ctx.clone())?;
160                }
161                if es.len() != n as usize {
162                    let type_actual = Type::vec(*typ.clone(), es.len().try_into().unwrap());
163                    Err(TypeError::NotExpectedType(type_expected.clone(), type_actual.clone(), self.clone()))
164                } else {
165                    Ok(())
166                }
167            },
168            (_type_expected, Expr::IdxField(_loc, _typ, e, field)) => {
169                // TODO probably want to infer idx exprs rather than check them.
170                match e.typeinfer(ctx.clone()) {
171                    Some(Type::Struct(typedef)) => {
172                        if let Some(type_actual) = typedef.type_of_field(field) {
173                            if type_expected == type_actual {
174                                Ok(())
175                            } else {
176                                return Err(TypeError::NotExpectedType(type_expected.clone(), type_actual.clone(), self.clone()));
177                            }
178                        } else {
179                            Err(TypeError::Other(self.clone(), format!("No such field: {field}")))
180                        }
181                    },
182                    Some(typ) => Err(TypeError::Other(self.clone(), format!("Expected struct type, not {typ:?}"))),
183                    None => Err(TypeError::Other(self.clone(), format!("Can't infer the type of {e:?}"))),
184                }
185            },
186            (_type_expected, Expr::Idx(_loc, _typ, e, i)) => {
187                match e.typeinfer(ctx.clone()) {
188                    Some(Type::Word(n)) if *i < n => Ok(()),
189                    Some(Type::Word(n)) => Err(TypeError::Other(self.clone(), format!("Index out of bounds"))),
190                    Some(typ) => Err(TypeError::Other(self.clone(), format!("Can't index into type {typ:?}"))),
191                    None => Err(TypeError::Other(self.clone(), format!("Can't infer the type of {e:?}"))),
192                }
193            },
194            (_type_expected, Expr::IdxRange(_loc, _typ, e, j, i)) => {
195                match e.typeinfer(ctx.clone()) {
196                    Some(Type::Word(n)) if n >= *j && j >= i => Ok(()),
197                    Some(Type::Word(_n)) => Err(TypeError::Other(self.clone(), format!("Index out of bounds"))),
198                    Some(typ) => Err(TypeError::Other(self.clone(), format!("Can't index into type {typ:?}"))),
199                    None => Err(TypeError::Other(self.clone(), format!("Can't infer the type of {e:?}"))),
200                }
201            },
202            (_type_expected, Expr::Call(_loc, _typ, fndef, es)) => {
203                // TODO
204                if fndef.args.len() != es.len() {
205                    let fn_name = &fndef.name;
206                    let m = fndef.args.len();
207                    let n = es.len();
208                    Err(TypeError::Other(self.clone(), format!("{fn_name} takes {n} args, but found {m} instead")))
209                } else {
210                    for ((_arg_name, arg_typ), e) in fndef.args.iter().zip(es.iter()) {
211                        e.typecheck(Type::clone(&arg_typ), ctx.clone())?;
212                    }
213
214                    Ok(())
215                }
216            },
217            (_type_expected, Expr::Hole(_loc, _typ, opt_name)) => Ok(()),
218            _ => Err(TypeError::Other(self.clone(), format!("{self:?} is not the expected type {type_expected:?}"))),
219        };
220
221        if let Some(typ) = self.type_of_cell() {
222            if let Ok(()) = &result {
223                let _ = typ.set(type_expected);
224            }
225        }
226        result
227    }
228
229    #[allow(unused_variables)] // TODO remove this
230    pub fn typeinfer(self: &Arc<Self>, ctx: Context<Path, Type>) -> Option<Type> {
231        let result = match &**self {
232            Expr::Reference(_loc, typ, path) => {
233                let type_actual = ctx.lookup(path)?;
234                Some(type_actual)
235            },
236            Expr::Net(_loc, _typ, netid) => panic!("Can't typecheck a net"),
237            Expr::Word(_loc, _typ, None, n) => None,
238            Expr::Word(_loc, _typ, Some(w), n) => if n >> w == 0 {
239                Some(Type::word(*w))
240            } else {
241                None
242            },
243            Expr::Enum(loc, _typ, typedef, _name) => {
244                Some(typedef.clone())
245            },
246            Expr::Cat(_loc, _typ, es) => {
247                let mut w = 0u64;
248                for e in es {
249                    if let Some(Type::Word(m)) = e.typeinfer(ctx.clone()) {
250                        w += m;
251                    } else {
252                        return None;
253                    }
254                }
255                Some(Type::word(w))
256            },
257            Expr::ToWord(loc, _typ, e) => {
258                match e.typeinfer(ctx.clone()) {
259                    Some(Type::Enum(typedef)) => {
260                        Some(Type::word(typedef.bitwidth()))
261                    }
262                    _ => None,
263                }
264            },
265            Expr::Vec(_loc, _typ, es) => None,
266            Expr::Idx(_loc, _typ, e, i) => {
267                match e.typeinfer(ctx.clone()) {
268                    Some(Type::Word(n)) if *i < n => Some(Type::word(1)),
269                    _ => None,
270                }
271            },
272            Expr::IdxRange(_loc, _typ, e, j, i) => {
273                match e.typeinfer(ctx.clone()) {
274                    Some(Type::Word(n)) if n >= *j && *j >= *i => Some(Type::word(*j - *i)),
275                    Some(Type::Word(n)) => None,
276                    Some(typ) => None,
277                    None => None,
278                }
279            },
280            Expr::Hole(_loc, _typ, opt_name) => None,
281            _ => None,
282        };
283
284        if let Some(type_actual) = &result {
285            if let Some(typ) = self.type_of_cell() {
286                let _ = typ.set(type_actual.clone());
287            }
288        }
289        result
290    }
291}
292
293impl Type {
294    fn valid_pat(&self, pat: &Pat) -> bool {
295        match pat {
296            Pat::At(ctor, subpats) => {
297                match &*self {
298                    Type::Valid(inner_type) => {
299                        if ctor == "Invalid" && subpats.len() == 0 {
300                            true
301                        } else if ctor == "Valid" && subpats.len() == 1 {
302                            inner_type.valid_pat(&subpats[0])
303                        } else {
304                            false
305                        }
306                    },
307                    Type::Enum(typedef) => {
308                        let alts: Vec<String> = typedef.values.iter().map(|(name, _val)| name.clone()).collect();
309                        alts.contains(ctor)
310                    },
311                    _ => false,
312                }
313            },
314            Pat::Bind(_x) => true,
315            Pat::Otherwise => true,
316        }
317    }
318
319    fn extend_context_for_pat(&self, ctx: Context<Path, Type>, pat: &Pat) -> Context<Path, Type> {
320        if let Type::Valid(inner_type) = self {
321            if let Pat::At(ctor, subpats) = pat {
322                if ctor == "Valid" && subpats.len() == 1 {
323                    if let Pat::Bind(x) = &subpats[0] {
324                        ctx.extend(x.clone().into(), *inner_type.clone())
325                    } else {
326                        unreachable!()
327                    }
328                } else if ctor == "Invalid" {
329                    ctx.clone()
330                } else {
331                    unreachable!()
332                }
333            } else {
334                ctx.clone()
335            }
336        } else {
337            ctx.clone()
338        }
339    }
340}