Skip to main content

lust/typechecker/expr_checker/
entry.rs

1use super::*;
2use alloc::{boxed::Box, format, string::ToString, vec::Vec};
3use hashbrown::HashMap;
4use crate::builtins;
5impl TypeChecker {
6    pub fn check_expr(&mut self, expr: &Expr) -> Result<Type> {
7        let mut ty = self.check_expr_with_hint(expr, None)?;
8        if ty.span.start_line == 0 && expr.span.start_line > 0 {
9            ty.span = expr.span;
10        }
11
12        if expr.span.start_line > 0 {
13            if let Some(module) = &self.current_module {
14                self.expr_types_by_module
15                    .entry(module.clone())
16                    .or_default()
17                    .insert(expr.span, ty.clone());
18            }
19        }
20
21        Ok(ty)
22    }
23
24    pub fn check_expr_with_hint(
25        &mut self,
26        expr: &Expr,
27        expected_type: Option<&Type>,
28    ) -> Result<Type> {
29        match &expr.kind {
30            ExprKind::Literal(lit) => self.check_literal(lit),
31            ExprKind::Identifier(name) => {
32                if let Some(var_type) = self.env.lookup_variable(name) {
33                    return Ok(var_type);
34                }
35
36                let resolved_func = self.resolve_function_key(name);
37                if let Some(func_sig) = self.env.lookup_function(&resolved_func) {
38                    return Ok(Type::new(
39                        TypeKind::Function {
40                            params: func_sig.params.clone(),
41                            return_type: Box::new(func_sig.return_type.clone()),
42                        },
43                        expr.span,
44                    ));
45                }
46
47                if self.resolve_module_alias(name).is_some() {
48                    return Ok(Type::new(TypeKind::Unknown, expr.span));
49                }
50
51                let resolved = self.resolve_value_key(name);
52                if let Some(const_type) = self
53                    .env
54                    .lookup_constant(&resolved)
55                    .or_else(|| self.env.lookup_constant(name))
56                {
57                    return Ok(const_type);
58                }
59
60                if builtins::base_functions()
61                    .iter()
62                    .any(|builtin| builtin.name == name)
63                {
64                    return Ok(Type::new(TypeKind::Unknown, expr.span));
65                }
66
67                Ok(Type::new(TypeKind::Unknown, expr.span))
68            }
69            ExprKind::Binary { left, op, right } => {
70                self.check_binary_expr(expr.span, left, op, right)
71            }
72            ExprKind::Unary { op, operand } => self.check_unary_expr(op, operand),
73            ExprKind::Call { callee, args } => self.check_call_expr(expr.span, callee, args),
74            ExprKind::MethodCall {
75                receiver,
76                method,
77                type_args: _,
78                args,
79            } => self.check_method_call(receiver, method, args),
80
81            ExprKind::FieldAccess { object, field } => {
82                self.check_field_access_with_hint(expr.span, object, field, expected_type)
83            }
84
85            ExprKind::Index { object, index } => self.check_index_expr(object, index),
86            ExprKind::Array(elements) => self.check_array_literal(elements, expected_type),
87            ExprKind::Map(entries) => self.check_map_literal(entries, expected_type),
88            ExprKind::StructLiteral { name, fields } => {
89                self.check_struct_literal(expr.span, name, fields)
90            }
91
92            ExprKind::Lambda {
93                params,
94                return_type,
95                body,
96            } => self.check_lambda(params, return_type.as_ref(), body),
97            ExprKind::Cast { expr, target_type } => {
98                let _expr_type = self.check_expr(expr)?;
99                Ok(target_type.clone())
100            }
101
102            ExprKind::TypeCheck {
103                expr,
104                check_type: _,
105            } => {
106                let _expr_type = self.check_expr(expr)?;
107                Ok(Type::new(TypeKind::Bool, Self::dummy_span()))
108            }
109
110            ExprKind::IsPattern { expr, pattern } => {
111                let scrutinee_type = self.check_expr(expr)?;
112                self.validate_is_pattern(pattern, &scrutinee_type)?;
113                Ok(Type::new(TypeKind::Bool, Self::dummy_span()))
114            }
115
116            ExprKind::If {
117                condition,
118                then_branch,
119                else_branch,
120            } => self.check_if_expr(condition, then_branch, else_branch),
121            ExprKind::Block(stmts) => {
122                self.env.push_scope();
123                let mut result_type = Type::new(TypeKind::Unit, Self::dummy_span());
124                for stmt in stmts {
125                    match &stmt.kind {
126                        StmtKind::Expr(expr) => {
127                            result_type = self.check_expr(expr)?;
128                        }
129
130                        StmtKind::Return(values) => {
131                            result_type = if values.is_empty() {
132                                Type::new(TypeKind::Unit, Self::dummy_span())
133                            } else if values.len() == 1 {
134                                let expected = self.current_function_return_type.clone();
135                                let mut raw =
136                                    self.check_expr_with_hint(&values[0], expected.as_ref())?;
137                                if raw.span.start_line == 0 && values[0].span.start_line > 0 {
138                                    raw.span = values[0].span;
139                                }
140
141                                self.canonicalize_type(&raw)
142                            } else {
143                                let mut el_types = Vec::new();
144                                for value in values {
145                                    let raw = self.check_expr(value)?;
146                                    el_types.push(self.canonicalize_type(&raw));
147                                }
148
149                                Type::new(TypeKind::Tuple(el_types), Self::dummy_span())
150                            };
151                            self.pending_generic_instances.take();
152                            self.check_stmt(stmt)?;
153                        }
154
155                        _ => {
156                            self.check_stmt(stmt)?;
157                        }
158                    }
159                }
160
161                self.env.pop_scope();
162                if result_type.span.start_line == 0 {
163                    result_type.span = expr.span;
164                }
165
166                Ok(result_type)
167            }
168
169            ExprKind::Range { .. } => Err(self.type_error_at(
170                "Range expressions are not supported; use numeric for-loops".to_string(),
171                expr.span,
172            )),
173            ExprKind::EnumConstructor {
174                enum_name,
175                variant,
176                args,
177            } => {
178                let enum_def = self
179                    .env
180                    .lookup_enum(enum_name)
181                    .ok_or_else(|| self.type_error(format!("Undefined enum '{}'", enum_name)))?
182                    .clone();
183                let variant_def = enum_def
184                    .variants
185                    .iter()
186                    .find(|v| v.name == *variant)
187                    .ok_or_else(|| {
188                        self.type_error(format!(
189                            "Enum '{}' has no variant '{}'",
190                            enum_name, variant
191                        ))
192                    })?;
193                if let Some(expected_fields) = &variant_def.fields {
194                    if args.len() != expected_fields.len() {
195                        return Err(self.type_error(format!(
196                            "Variant '{}::{}' expects {} arguments, got {}",
197                            enum_name,
198                            variant,
199                            expected_fields.len(),
200                            args.len()
201                        )));
202                    }
203
204                    let mut type_params = HashMap::new();
205                    for (arg, expected_type) in args.iter().zip(expected_fields.iter()) {
206                        let arg_type = self.check_expr(arg)?;
207                        if let TypeKind::Generic(type_param) = &expected_type.kind {
208                            type_params.insert(type_param.clone(), arg_type.clone());
209                        } else {
210                            self.unify(expected_type, &arg_type)?;
211                        }
212                    }
213
214                    if !type_params.is_empty() {
215                        self.pending_generic_instances = Some(type_params.clone());
216                    }
217
218                    if enum_name == "Option" {
219                        if let Some(inner_type) = type_params.get("T") {
220                            return Ok(Type::new(
221                                TypeKind::Option(Box::new(inner_type.clone())),
222                                Self::dummy_span(),
223                            ));
224                        }
225                    } else if enum_name == "Result" {
226                        if let (Some(ok_type), Some(err_type)) =
227                            (type_params.get("T"), type_params.get("E"))
228                        {
229                            return Ok(Type::new(
230                                TypeKind::Result(
231                                    Box::new(ok_type.clone()),
232                                    Box::new(err_type.clone()),
233                                ),
234                                Self::dummy_span(),
235                            ));
236                        }
237                    }
238                } else {
239                    if !args.is_empty() {
240                        return Err(self.type_error(format!(
241                            "Variant '{}::{}' is a unit variant and takes no arguments",
242                            enum_name, variant
243                        )));
244                    }
245                }
246
247                Ok(Type::new(
248                    TypeKind::Named(enum_name.clone()),
249                    Self::dummy_span(),
250                ))
251            }
252
253            ExprKind::Tuple(elements) => {
254                let expected_elements = expected_type.and_then(|ty| {
255                    if let TypeKind::Tuple(elems) = &ty.kind {
256                        Some(elems.clone())
257                    } else {
258                        None
259                    }
260                });
261                let mut element_types = Vec::new();
262                for (index, element) in elements.iter().enumerate() {
263                    let hint = expected_elements
264                        .as_ref()
265                        .and_then(|elems| elems.get(index));
266                    let mut raw_ty = if let Some(hint_ty) = hint {
267                        self.check_expr_with_hint(element, Some(hint_ty))?
268                    } else {
269                        self.check_expr(element)?
270                    };
271                    if raw_ty.span.start_line == 0 && element.span.start_line > 0 {
272                        raw_ty.span = element.span;
273                    }
274
275                    self.pending_generic_instances.take();
276                    element_types.push(self.canonicalize_type(&raw_ty));
277                }
278
279                Ok(Type::new(TypeKind::Tuple(element_types), expr.span))
280            }
281
282            ExprKind::Return(exprs) => {
283                let mut return_type = if exprs.is_empty() {
284                    Type::new(TypeKind::Unit, Self::dummy_span())
285                } else if exprs.len() == 1 {
286                    let expected = self.current_function_return_type.clone();
287                    let mut raw_ty = self.check_expr_with_hint(&exprs[0], expected.as_ref())?;
288                    if raw_ty.span.start_line == 0 && exprs[0].span.start_line > 0 {
289                        raw_ty.span = exprs[0].span;
290                    }
291
292                    self.pending_generic_instances.take();
293                    raw_ty
294                } else {
295                    let mut types = Vec::new();
296                    for value in exprs {
297                        let raw_ty = self.check_expr(value)?;
298                        let ty = self.canonicalize_type(&raw_ty);
299                        self.pending_generic_instances.take();
300                        types.push(ty);
301                    }
302
303                    Type::new(TypeKind::Tuple(types), Self::dummy_span())
304                };
305                if return_type.span.start_line == 0 {
306                    if let Some(first) = exprs.first() {
307                        return_type.span = first.span;
308                    } else {
309                        return_type.span = expr.span;
310                    }
311                }
312
313                if let Some(expected_return) = &self.current_function_return_type {
314                    self.unify(expected_return, &return_type)?;
315                } else {
316                    return Err(self.type_error("'return' outside of function".to_string()));
317                }
318
319                Ok(return_type)
320            }
321
322            ExprKind::Paren(inner) => self.check_expr(inner),
323        }
324    }
325}