lust/typechecker/expr_checker/collections.rs

use std::collections::HashSet;

use super::*;
2impl TypeChecker {
3    pub fn check_array_literal(
4        &mut self,
5        elements: &[Expr],
6        expected_type: Option<&Type>,
7    ) -> Result<Type> {
8        if elements.is_empty() {
9            if let Some(expected) = expected_type {
10                return Ok(expected.clone());
11            }
12
13            let span = Self::dummy_span();
14            return Ok(Type::new(
15                TypeKind::Array(Box::new(Type::new(TypeKind::Unknown, span))),
16                span,
17            ));
18        }
19
20        let expected_elem_type = expected_type.and_then(|t| {
21            if let TypeKind::Array(elem_type) = &t.kind {
22                Some(elem_type.as_ref())
23            } else {
24                None
25            }
26        });
27        if let Some(expected_elem) = expected_elem_type {
28            if let TypeKind::Union(union_types) = &expected_elem.kind {
29                for elem in elements {
30                    let elem_type = self.check_expr(elem)?;
31                    let mut matches = false;
32                    for union_variant in union_types {
33                        if self.types_equal(&elem_type, union_variant) {
34                            matches = true;
35                            break;
36                        }
37                    }
38
39                    if !matches {
40                        let union_desc = union_types
41                            .iter()
42                            .map(|t| t.to_string())
43                            .collect::<Vec<_>>()
44                            .join(" | ");
45                        return Err(self.type_error(format!(
46                            "Array element type '{}' does not match any type in union [{}]",
47                            elem_type, union_desc
48                        )));
49                    }
50                }
51
52                return Ok(expected_type.unwrap().clone());
53            }
54        }
55
56        if let Some(expected_elem) = expected_elem_type {
57            if matches!(expected_elem.kind, TypeKind::Unknown) {
58                for elem in elements {
59                    self.check_expr(elem)?;
60                }
61
62                return Ok(expected_type.unwrap().clone());
63            }
64
65            if let TypeKind::Option(inner) = &expected_elem.kind {
66                if matches!(inner.kind, TypeKind::Unknown) {
67                    for elem in elements {
68                        let elem_type = self.check_expr(elem)?;
69                        let is_option = matches!(&elem_type.kind, TypeKind::Option(_))
70                            || matches!(&elem_type.kind, TypeKind::Named(name) if name == "Option");
71                        if !is_option {
72                            return Err(self.type_error(format!(
73                                "Expected Option type for Array<Option<unknown>>, got '{}'",
74                                elem_type
75                            )));
76                        }
77                    }
78
79                    return Ok(expected_type.unwrap().clone());
80                }
81            }
82
83            if let TypeKind::Result(ok_inner, err_inner) = &expected_elem.kind {
84                if matches!(ok_inner.kind, TypeKind::Unknown)
85                    || matches!(err_inner.kind, TypeKind::Unknown)
86                {
87                    for elem in elements {
88                        let elem_type = self.check_expr(elem)?;
89                        let is_result = matches!(&elem_type.kind, TypeKind::Result(_, _))
90                            || matches!(&elem_type.kind, TypeKind::Named(name) if name == "Result");
91                        if !is_result {
92                            return Err(self.type_error(format!(
93                                "Expected Result type for Array<Result<unknown, ...>>, got '{}'",
94                                elem_type
95                            )));
96                        }
97                    }
98
99                    return Ok(expected_type.unwrap().clone());
100                }
101            }
102        }
103
104        let first_type = self.check_expr(&elements[0])?;
105        for elem in &elements[1..] {
106            let elem_type = self.check_expr(elem)?;
107            self.unify(&first_type, &elem_type)?;
108        }
109
110        Ok(Type::new(
111            TypeKind::Array(Box::new(first_type)),
112            Self::dummy_span(),
113        ))
114    }
115
116    pub fn check_map_literal(&mut self, entries: &[(Expr, Expr)]) -> Result<Type> {
117        if entries.is_empty() {
118            let span = Self::dummy_span();
119            return Ok(Type::new(
120                TypeKind::Map(
121                    Box::new(Type::new(TypeKind::Unknown, span)),
122                    Box::new(Type::new(TypeKind::Unknown, span)),
123                ),
124                span,
125            ));
126        }
127
128        let (first_key, first_value) = &entries[0];
129        let key_type = self.check_expr(first_key)?;
130        let value_type = self.check_expr(first_value)?;
131        if !self.env.type_implements_trait(&key_type, "Hashable") {
132            return Err(self.type_error(format!(
133                "Map key type '{}' must implement Hashable trait",
134                key_type
135            )));
136        }
137
138        for (key, value) in &entries[1..] {
139            let k_type = self.check_expr(key)?;
140            let v_type = self.check_expr(value)?;
141            self.unify(&key_type, &k_type)?;
142            self.unify(&value_type, &v_type)?;
143        }
144
145        Ok(Type::new(
146            TypeKind::Map(Box::new(key_type), Box::new(value_type)),
147            Self::dummy_span(),
148        ))
149    }
150
151    pub fn check_struct_literal(
152        &mut self,
153        span: Span,
154        name: &str,
155        fields: &[StructLiteralField],
156    ) -> Result<Type> {
157        let key = self.resolve_type_key(name);
158        let struct_def = self
159            .env
160            .lookup_struct(&key)
161            .or_else(|| self.env.lookup_struct(name))
162            .ok_or_else(|| self.type_error_at(format!("Undefined struct '{}'", name), span))?
163            .clone();
164        if fields.len() != struct_def.fields.len() {
165            return Err(self.type_error_at(
166                format!(
167                    "Struct '{}' has {} fields, but {} were provided",
168                    name,
169                    struct_def.fields.len(),
170                    fields.len()
171                ),
172                span,
173            ));
174        }
175
176        for field in fields {
177            let expected_type = struct_def
178                .fields
179                .iter()
180                .find(|f| f.name == field.name)
181                .map(|f| &f.ty)
182                .ok_or_else(|| {
183                    self.type_error_at(
184                        format!("Struct '{}' has no field '{}'", name, field.name),
185                        field.span,
186                    )
187                })?;
188            let actual_type = self.check_expr(&field.value)?;
189            match &expected_type.kind {
190                TypeKind::Option(inner_expected) => {
191                    if self.unify(inner_expected, &actual_type).is_err() {
192                        self.unify(expected_type, &actual_type)?;
193                    }
194                }
195
196                _ => {
197                    self.unify(expected_type, &actual_type)?;
198                }
199            }
200        }
201
202        let ty_name = if self.env.lookup_struct(&key).is_some() {
203            key
204        } else {
205            name.to_string()
206        };
207        Ok(Type::new(TypeKind::Named(ty_name), Self::dummy_span()))
208    }
209
210    pub fn check_lambda(
211        &mut self,
212        params: &[(String, Option<Type>)],
213        return_type: Option<&Type>,
214        body: &Expr,
215    ) -> Result<Type> {
216        self.env.push_scope();
217        let expected_signature = self.expected_lambda_signature.take();
218        let mut param_types = Vec::new();
219        for (i, (param_name, param_type)) in params.iter().enumerate() {
220            let ty = if let Some(explicit_type) = param_type {
221                explicit_type.clone()
222            } else if let Some((ref expected_params, _)) = expected_signature {
223                if i < expected_params.len() {
224                    expected_params[i].clone()
225                } else {
226                    Type::new(TypeKind::Infer, Self::dummy_span())
227                }
228            } else {
229                Type::new(TypeKind::Infer, Self::dummy_span())
230            };
231            self.env.declare_variable(param_name.clone(), ty.clone())?;
232            param_types.push(ty);
233        }
234
235        let saved_return_type = self.current_function_return_type.clone();
236        let inferred_return_type = if let Some(explicit) = return_type {
237            Some(explicit.clone())
238        } else if let Some((_, expected_ret)) = expected_signature {
239            expected_ret.or_else(|| Some(Type::new(TypeKind::Infer, Self::dummy_span())))
240        } else {
241            Some(Type::new(TypeKind::Infer, Self::dummy_span()))
242        };
243        self.current_function_return_type = inferred_return_type.clone();
244        let body_type = self.check_expr(body)?;
245        self.current_function_return_type = saved_return_type;
246        let actual_return_type = if let Some(expected) = return_type {
247            expected.clone()
248        } else if let Some(inferred) = &inferred_return_type {
249            if !matches!(inferred.kind, TypeKind::Infer) {
250                inferred.clone()
251            } else {
252                body_type
253            }
254        } else {
255            body_type
256        };
257        self.env.pop_scope();
258        Ok(Type::new(
259            TypeKind::Function {
260                params: param_types,
261                return_type: Box::new(actual_return_type),
262            },
263            Self::dummy_span(),
264        ))
265    }
266
267    pub fn check_if_expr(
268        &mut self,
269        condition: &Expr,
270        then_branch: &Expr,
271        else_branch: &Option<Box<Expr>>,
272    ) -> Result<Type> {
273        let cond_type = self.check_expr(condition)?;
274        self.unify(&Type::new(TypeKind::Bool, Self::dummy_span()), &cond_type)?;
275        let then_type = self.check_expr(then_branch)?;
276        if let Some(else_expr) = else_branch {
277            let else_type = self.check_expr(else_expr)?;
278            self.unify(&then_type, &else_type)?;
279            Ok(then_type)
280        } else {
281            Ok(Type::new(TypeKind::Unit, Self::dummy_span()))
282        }
283    }
284}