lust/typechecker/expr_checker/
patterns.rs

1use super::*;
2impl TypeChecker {
3    pub fn validate_is_pattern(&mut self, pattern: &Pattern, scrutinee_type: &Type) -> Result<()> {
4        match pattern {
5            Pattern::Wildcard | Pattern::Literal(_) | Pattern::Identifier(_) => Ok(()),
6            Pattern::TypeCheck(check_type) => {
7                let _ = check_type;
8                Ok(())
9            }
10
11            Pattern::Enum {
12                enum_name: _,
13                variant,
14                bindings,
15            } => {
16                let (type_name, variant_types) = match &scrutinee_type.kind {
17                    TypeKind::Named(name) => (name.clone(), None),
18                    TypeKind::Option(inner) => {
19                        ("Option".to_string(), Some(vec![(**inner).clone()]))
20                    }
21
22                    TypeKind::Result(ok, err) => (
23                        "Result".to_string(),
24                        Some(vec![(**ok).clone(), (**err).clone()]),
25                    ),
26                    TypeKind::Union(types) => {
27                        for ty in types.iter() {
28                            if let TypeKind::Named(name) = &ty.kind {
29                                if let Some(_) = {
30                                    let key = self.resolve_type_key(name);
31                                    self.env
32                                        .lookup_enum(&key)
33                                        .or_else(|| self.env.lookup_enum(name))
34                                } {
35                                    return Ok(());
36                                }
37                            }
38
39                            if matches!(ty.kind, TypeKind::Option(_) | TypeKind::Result(_, _)) {
40                                return Ok(());
41                            }
42                        }
43
44                        return Err(self.type_error(format!(
45                            "Union type '{}' does not contain enum types compatible with variant '{}'",
46                            scrutinee_type, variant
47                        )));
48                    }
49
50                    _ => {
51                        return Err(self.type_error(format!(
52                            "Cannot use enum pattern on non-enum type '{}'",
53                            scrutinee_type
54                        )))
55                    }
56                };
57                let enum_def = {
58                    let key = self.resolve_type_key(&type_name);
59                    self.env
60                        .lookup_enum(&key)
61                        .or_else(|| self.env.lookup_enum(&type_name))
62                }
63                .ok_or_else(|| self.type_error(format!("Undefined enum '{}'", type_name)))?
64                .clone();
65                let variant_def = enum_def
66                    .variants
67                    .iter()
68                    .find(|v| &v.name == variant)
69                    .ok_or_else(|| {
70                        self.type_error(format!(
71                            "Enum '{}' has no variant '{}'",
72                            type_name, variant
73                        ))
74                    })?;
75                if let Some(variant_fields) = &variant_def.fields {
76                    if bindings.len() != variant_fields.len() {
77                        return Err(self.type_error(format!(
78                            "Variant '{}::{}' expects {} bindings, got {}",
79                            type_name,
80                            variant,
81                            variant_fields.len(),
82                            bindings.len()
83                        )));
84                    }
85
86                    for (binding, field_type) in bindings.iter().zip(variant_fields.iter()) {
87                        let bind_type = if let Some(ref types) = variant_types {
88                            if let TypeKind::Generic(_) = &field_type.kind {
89                                types.get(0).cloned().unwrap_or_else(|| field_type.clone())
90                            } else {
91                                field_type.clone()
92                            }
93                        } else {
94                            field_type.clone()
95                        };
96                        self.validate_is_pattern(binding, &bind_type)?;
97                    }
98                } else {
99                    if !bindings.is_empty() {
100                        return Err(self.type_error(format!(
101                            "Variant '{}::{}' is a unit variant and takes no bindings",
102                            type_name, variant
103                        )));
104                    }
105                }
106
107                Ok(())
108            }
109
110            Pattern::Struct { .. } => Ok(()),
111        }
112    }
113
114    pub fn extract_type_narrowings_from_expr(&mut self, expr: &Expr) -> Vec<(String, Type)> {
115        let mut narrowings = Vec::new();
116        match &expr.kind {
117            ExprKind::TypeCheck {
118                expr: scrutinee,
119                check_type: target_type,
120            } => {
121                if let ExprKind::Identifier(var_name) = &scrutinee.kind {
122                    if let Some(current_type) = self.env.lookup_variable(var_name) {
123                        let narrowed_type = if let TypeKind::Named(name) = &target_type.kind {
124                            let resolved = self.resolve_type_key(name);
125                            if self.env.lookup_trait(&resolved).is_some() {
126                                Type::new(TypeKind::Trait(name.clone()), target_type.span)
127                            } else {
128                                target_type.clone()
129                            }
130                        } else {
131                            target_type.clone()
132                        };
133                        match &current_type.kind {
134                            TypeKind::Unknown => {
135                                narrowings.push((var_name.clone(), narrowed_type));
136                            }
137
138                            TypeKind::Union(types) => {
139                                for ty in types {
140                                    if self.types_equal(ty, target_type) {
141                                        narrowings.push((var_name.clone(), target_type.clone()));
142                                        break;
143                                    }
144                                }
145                            }
146
147                            _ => {}
148                        }
149                    }
150                }
151            }
152
153            ExprKind::IsPattern {
154                expr: scrutinee,
155                pattern,
156            } => {
157                if let Pattern::TypeCheck(target_type) = pattern {
158                    if let ExprKind::Identifier(var_name) = &scrutinee.kind {
159                        if let Some(current_type) = self.env.lookup_variable(var_name) {
160                            match &current_type.kind {
161                                TypeKind::Unknown => {
162                                    narrowings.push((var_name.clone(), target_type.clone()));
163                                }
164
165                                TypeKind::Union(types) => {
166                                    for ty in types {
167                                        if self.types_equal(ty, target_type) {
168                                            narrowings
169                                                .push((var_name.clone(), target_type.clone()));
170                                            break;
171                                        }
172                                    }
173                                }
174
175                                _ => {}
176                            }
177                        }
178                    }
179                }
180            }
181
182            ExprKind::Binary { left, op, right } => {
183                if matches!(op, BinaryOp::And) {
184                    narrowings.extend(self.extract_type_narrowings_from_expr(left));
185                    narrowings.extend(self.extract_type_narrowings_from_expr(right));
186                }
187            }
188
189            _ => {}
190        }
191
192        narrowings
193    }
194
195    pub fn extract_all_pattern_bindings_from_expr<'a>(
196        &self,
197        expr: &'a Expr,
198    ) -> Vec<(&'a Expr, Pattern)> {
199        let mut bindings = Vec::new();
200        match &expr.kind {
201            ExprKind::IsPattern {
202                expr: scrutinee,
203                pattern,
204            } => match pattern {
205                Pattern::Enum {
206                    bindings: pattern_bindings,
207                    ..
208                } if !pattern_bindings.is_empty() => {
209                    bindings.push((scrutinee.as_ref(), pattern.clone()));
210                }
211
212                _ => {}
213            },
214            ExprKind::Binary { left, op, right } => {
215                if matches!(op, BinaryOp::And) {
216                    bindings.extend(self.extract_all_pattern_bindings_from_expr(left));
217                    bindings.extend(self.extract_all_pattern_bindings_from_expr(right));
218                }
219            }
220
221            _ => {}
222        }
223
224        bindings
225    }
226
227    pub fn bind_pattern(&mut self, pattern: &Pattern, scrutinee_type: &Type) -> Result<()> {
228        match pattern {
229            Pattern::Wildcard => Ok(()),
230            Pattern::Identifier(name) => self
231                .env
232                .declare_variable(name.clone(), scrutinee_type.clone()),
233            Pattern::Literal(_) => Ok(()),
234            Pattern::Struct { name: _, fields: _ } => Ok(()),
235            Pattern::Enum {
236                enum_name: _,
237                variant,
238                bindings,
239            } => {
240                let (type_name, variant_types) = match &scrutinee_type.kind {
241                    TypeKind::Named(name) => (name.clone(), None),
242                    TypeKind::Option(inner) => {
243                        ("Option".to_string(), Some(vec![(**inner).clone()]))
244                    }
245
246                    TypeKind::Result(ok, err) => (
247                        "Result".to_string(),
248                        Some(vec![(**ok).clone(), (**err).clone()]),
249                    ),
250                    _ => {
251                        return Err(self
252                            .type_error(format!("Expected enum type, got '{}'", scrutinee_type)))
253                    }
254                };
255                let enum_def = {
256                    let key = self.resolve_type_key(&type_name);
257                    self.env
258                        .lookup_enum(&key)
259                        .or_else(|| self.env.lookup_enum(&type_name))
260                }
261                .ok_or_else(|| self.type_error(format!("Undefined enum '{}'", type_name)))?
262                .clone();
263                let variant_def = enum_def
264                    .variants
265                    .iter()
266                    .find(|v| &v.name == variant)
267                    .ok_or_else(|| {
268                        self.type_error(format!(
269                            "Enum '{}' has no variant '{}'",
270                            type_name, variant
271                        ))
272                    })?;
273                if let Some(variant_fields) = &variant_def.fields {
274                    if bindings.len() != variant_fields.len() {
275                        return Err(self.type_error(format!(
276                            "Variant '{}::{}' expects {} bindings, got {}",
277                            type_name,
278                            variant,
279                            variant_fields.len(),
280                            bindings.len()
281                        )));
282                    }
283
284                    for (i, (binding, field_type)) in
285                        bindings.iter().zip(variant_fields.iter()).enumerate()
286                    {
287                        let concrete =
288                            variant_types
289                                .as_ref()
290                                .and_then(|types| match type_name.as_str() {
291                                    "Option" => {
292                                        if variant == "Some" {
293                                            types.get(0).cloned()
294                                        } else {
295                                            None
296                                        }
297                                    }
298
299                                    "Result" => match variant.as_str() {
300                                        "Ok" => types.get(0).cloned(),
301                                        "Err" => types.get(1).cloned(),
302                                        _ => types.get(i).cloned(),
303                                    },
304                                    _ => types.get(i).cloned(),
305                                });
306                        let bind_type = if let Some(concrete_type) = concrete {
307                            concrete_type
308                        } else if matches!(field_type.kind, TypeKind::Generic(_)) {
309                            Type::new(TypeKind::Unknown, Self::dummy_span())
310                        } else {
311                            field_type.clone()
312                        };
313                        self.bind_pattern(binding, &bind_type)?;
314                    }
315                } else {
316                    if !bindings.is_empty() {
317                        return Err(self.type_error(format!(
318                            "Variant '{}::{}' is a unit variant and has no bindings",
319                            type_name, variant
320                        )));
321                    }
322                }
323
324                Ok(())
325            }
326
327            Pattern::TypeCheck(_) => Ok(()),
328        }
329    }
330}