lust/typechecker/expr_checker/
patterns.rs

1use super::*;
2use alloc::{
3    format,
4    string::{String, ToString},
5    vec,
6    vec::Vec,
7};
8impl TypeChecker {
9    pub fn validate_is_pattern(&mut self, pattern: &Pattern, scrutinee_type: &Type) -> Result<()> {
10        match pattern {
11            Pattern::Wildcard | Pattern::Literal(_) | Pattern::Identifier(_) => Ok(()),
12            Pattern::TypeCheck(check_type) => {
13                let _ = check_type;
14                Ok(())
15            }
16
17            Pattern::Enum {
18                enum_name: _,
19                variant,
20                bindings,
21            } => {
22                let (type_name, variant_types) = match &scrutinee_type.kind {
23                    TypeKind::Named(name) => (name.clone(), None),
24                    TypeKind::Option(inner) => {
25                        ("Option".to_string(), Some(vec![(**inner).clone()]))
26                    }
27
28                    TypeKind::Result(ok, err) => (
29                        "Result".to_string(),
30                        Some(vec![(**ok).clone(), (**err).clone()]),
31                    ),
32                    TypeKind::Union(types) => {
33                        for ty in types.iter() {
34                            if let TypeKind::Named(name) = &ty.kind {
35                                if let Some(_) = {
36                                    let key = self.resolve_type_key(name);
37                                    self.env
38                                        .lookup_enum(&key)
39                                        .or_else(|| self.env.lookup_enum(name))
40                                } {
41                                    return Ok(());
42                                }
43                            }
44
45                            if matches!(ty.kind, TypeKind::Option(_) | TypeKind::Result(_, _)) {
46                                return Ok(());
47                            }
48                        }
49
50                        return Err(self.type_error(format!(
51                            "Union type '{}' does not contain enum types compatible with variant '{}'",
52                            scrutinee_type, variant
53                        )));
54                    }
55
56                    _ => {
57                        return Err(self.type_error(format!(
58                            "Cannot use enum pattern on non-enum type '{}'",
59                            scrutinee_type
60                        )))
61                    }
62                };
63                let enum_def = {
64                    let key = self.resolve_type_key(&type_name);
65                    self.env
66                        .lookup_enum(&key)
67                        .or_else(|| self.env.lookup_enum(&type_name))
68                }
69                .ok_or_else(|| self.type_error(format!("Undefined enum '{}'", type_name)))?
70                .clone();
71                let variant_def = enum_def
72                    .variants
73                    .iter()
74                    .find(|v| &v.name == variant)
75                    .ok_or_else(|| {
76                        self.type_error(format!(
77                            "Enum '{}' has no variant '{}'",
78                            type_name, variant
79                        ))
80                    })?;
81                if let Some(variant_fields) = &variant_def.fields {
82                    if bindings.len() != variant_fields.len() {
83                        return Err(self.type_error(format!(
84                            "Variant '{}::{}' expects {} bindings, got {}",
85                            type_name,
86                            variant,
87                            variant_fields.len(),
88                            bindings.len()
89                        )));
90                    }
91
92                    for (binding, field_type) in bindings.iter().zip(variant_fields.iter()) {
93                        let bind_type = if let Some(ref types) = variant_types {
94                            if let TypeKind::Generic(_) = &field_type.kind {
95                                types.get(0).cloned().unwrap_or_else(|| field_type.clone())
96                            } else {
97                                field_type.clone()
98                            }
99                        } else {
100                            field_type.clone()
101                        };
102                        self.validate_is_pattern(binding, &bind_type)?;
103                    }
104                } else {
105                    if !bindings.is_empty() {
106                        return Err(self.type_error(format!(
107                            "Variant '{}::{}' is a unit variant and takes no bindings",
108                            type_name, variant
109                        )));
110                    }
111                }
112
113                Ok(())
114            }
115
116            Pattern::Struct { .. } => Ok(()),
117        }
118    }
119
120    pub fn extract_type_narrowings_from_expr(&mut self, expr: &Expr) -> Vec<(String, Type)> {
121        let mut narrowings = Vec::new();
122        match &expr.kind {
123            ExprKind::TypeCheck {
124                expr: scrutinee,
125                check_type: target_type,
126            } => {
127                if let ExprKind::Identifier(var_name) = &scrutinee.kind {
128                    if let Some(current_type) = self.env.lookup_variable(var_name) {
129                        let narrowed_type = if let TypeKind::Named(name) = &target_type.kind {
130                            let resolved = self.resolve_type_key(name);
131                            if self.env.lookup_trait(&resolved).is_some() {
132                                Type::new(TypeKind::Trait(name.clone()), target_type.span)
133                            } else {
134                                target_type.clone()
135                            }
136                        } else {
137                            target_type.clone()
138                        };
139                        match &current_type.kind {
140                            TypeKind::Unknown => {
141                                narrowings.push((var_name.clone(), narrowed_type));
142                            }
143
144                            TypeKind::Union(types) => {
145                                for ty in types {
146                                    if self.types_equal(ty, target_type) {
147                                        narrowings.push((var_name.clone(), target_type.clone()));
148                                        break;
149                                    }
150                                }
151                            }
152
153                            _ => {}
154                        }
155                    }
156                }
157            }
158
159            ExprKind::IsPattern {
160                expr: scrutinee,
161                pattern,
162            } => {
163                if let Pattern::TypeCheck(target_type) = pattern {
164                    if let ExprKind::Identifier(var_name) = &scrutinee.kind {
165                        if let Some(current_type) = self.env.lookup_variable(var_name) {
166                            match &current_type.kind {
167                                TypeKind::Unknown => {
168                                    narrowings.push((var_name.clone(), target_type.clone()));
169                                }
170
171                                TypeKind::Union(types) => {
172                                    for ty in types {
173                                        if self.types_equal(ty, target_type) {
174                                            narrowings
175                                                .push((var_name.clone(), target_type.clone()));
176                                            break;
177                                        }
178                                    }
179                                }
180
181                                _ => {}
182                            }
183                        }
184                    }
185                }
186            }
187
188            ExprKind::Binary { left, op, right } => {
189                if matches!(op, BinaryOp::And) {
190                    narrowings.extend(self.extract_type_narrowings_from_expr(left));
191                    narrowings.extend(self.extract_type_narrowings_from_expr(right));
192                }
193            }
194
195            ExprKind::Paren(inner) => {
196                narrowings.extend(self.extract_type_narrowings_from_expr(inner));
197            }
198
199            _ => {}
200        }
201
202        narrowings
203    }
204
205    pub fn extract_all_pattern_bindings_from_expr<'a>(
206        &self,
207        expr: &'a Expr,
208    ) -> Vec<(&'a Expr, Pattern)> {
209        let mut bindings = Vec::new();
210        match &expr.kind {
211            ExprKind::IsPattern {
212                expr: scrutinee,
213                pattern,
214            } => match pattern {
215                Pattern::Enum {
216                    bindings: pattern_bindings,
217                    ..
218                } if !pattern_bindings.is_empty() => {
219                    bindings.push((scrutinee.as_ref(), pattern.clone()));
220                }
221
222                _ => {}
223            },
224            ExprKind::Binary { left, op, right } => {
225                if matches!(op, BinaryOp::And) {
226                    bindings.extend(self.extract_all_pattern_bindings_from_expr(left));
227                    bindings.extend(self.extract_all_pattern_bindings_from_expr(right));
228                }
229            }
230
231            ExprKind::Paren(inner) => {
232                bindings.extend(self.extract_all_pattern_bindings_from_expr(inner));
233            }
234
235            _ => {}
236        }
237
238        bindings
239    }
240
241    pub fn bind_pattern(&mut self, pattern: &Pattern, scrutinee_type: &Type) -> Result<()> {
242        match pattern {
243            Pattern::Wildcard => Ok(()),
244            Pattern::Identifier(name) => self
245                .env
246                .declare_variable(name.clone(), scrutinee_type.clone()),
247            Pattern::Literal(_) => Ok(()),
248            Pattern::Struct { name: _, fields: _ } => Ok(()),
249            Pattern::Enum {
250                enum_name: _,
251                variant,
252                bindings,
253            } => {
254                let (type_name, variant_types) = match &scrutinee_type.kind {
255                    TypeKind::Named(name) => (name.clone(), None),
256                    TypeKind::Option(inner) => {
257                        ("Option".to_string(), Some(vec![(**inner).clone()]))
258                    }
259
260                    TypeKind::Result(ok, err) => (
261                        "Result".to_string(),
262                        Some(vec![(**ok).clone(), (**err).clone()]),
263                    ),
264                    _ => {
265                        return Err(self
266                            .type_error(format!("Expected enum type, got '{}'", scrutinee_type)))
267                    }
268                };
269                let enum_def = {
270                    let key = self.resolve_type_key(&type_name);
271                    self.env
272                        .lookup_enum(&key)
273                        .or_else(|| self.env.lookup_enum(&type_name))
274                }
275                .ok_or_else(|| self.type_error(format!("Undefined enum '{}'", type_name)))?
276                .clone();
277                let variant_def = enum_def
278                    .variants
279                    .iter()
280                    .find(|v| &v.name == variant)
281                    .ok_or_else(|| {
282                        self.type_error(format!(
283                            "Enum '{}' has no variant '{}'",
284                            type_name, variant
285                        ))
286                    })?;
287                if let Some(variant_fields) = &variant_def.fields {
288                    if bindings.len() != variant_fields.len() {
289                        return Err(self.type_error(format!(
290                            "Variant '{}::{}' expects {} bindings, got {}",
291                            type_name,
292                            variant,
293                            variant_fields.len(),
294                            bindings.len()
295                        )));
296                    }
297
298                    for (i, (binding, field_type)) in
299                        bindings.iter().zip(variant_fields.iter()).enumerate()
300                    {
301                        let concrete =
302                            variant_types
303                                .as_ref()
304                                .and_then(|types| match type_name.as_str() {
305                                    "Option" => {
306                                        if variant == "Some" {
307                                            types.get(0).cloned()
308                                        } else {
309                                            None
310                                        }
311                                    }
312
313                                    "Result" => match variant.as_str() {
314                                        "Ok" => types.get(0).cloned(),
315                                        "Err" => types.get(1).cloned(),
316                                        _ => types.get(i).cloned(),
317                                    },
318                                    _ => types.get(i).cloned(),
319                                });
320                        let bind_type = if let Some(concrete_type) = concrete {
321                            concrete_type
322                        } else if matches!(field_type.kind, TypeKind::Generic(_)) {
323                            Type::new(TypeKind::Unknown, Self::dummy_span())
324                        } else {
325                            field_type.clone()
326                        };
327                        self.bind_pattern(binding, &bind_type)?;
328                    }
329                } else {
330                    if !bindings.is_empty() {
331                        return Err(self.type_error(format!(
332                            "Variant '{}::{}' is a unit variant and has no bindings",
333                            type_name, variant
334                        )));
335                    }
336                }
337
338                Ok(())
339            }
340
341            Pattern::TypeCheck(_) => Ok(()),
342        }
343    }
344}