// lust/typechecker/expr_checker/patterns.rs
use super::*;
2impl TypeChecker {
3 pub fn validate_is_pattern(&mut self, pattern: &Pattern, scrutinee_type: &Type) -> Result<()> {
4 match pattern {
5 Pattern::Wildcard | Pattern::Literal(_) | Pattern::Identifier(_) => Ok(()),
6 Pattern::TypeCheck(check_type) => {
7 let _ = check_type;
8 Ok(())
9 }
10
11 Pattern::Enum {
12 enum_name: _,
13 variant,
14 bindings,
15 } => {
16 let (type_name, variant_types) = match &scrutinee_type.kind {
17 TypeKind::Named(name) => (name.clone(), None),
18 TypeKind::Option(inner) => {
19 ("Option".to_string(), Some(vec![(**inner).clone()]))
20 }
21
22 TypeKind::Result(ok, err) => (
23 "Result".to_string(),
24 Some(vec![(**ok).clone(), (**err).clone()]),
25 ),
26 TypeKind::Union(types) => {
27 for ty in types.iter() {
28 if let TypeKind::Named(name) = &ty.kind {
29 if let Some(_) = {
30 let key = self.resolve_type_key(name);
31 self.env
32 .lookup_enum(&key)
33 .or_else(|| self.env.lookup_enum(name))
34 } {
35 return Ok(());
36 }
37 }
38
39 if matches!(ty.kind, TypeKind::Option(_) | TypeKind::Result(_, _)) {
40 return Ok(());
41 }
42 }
43
44 return Err(self.type_error(format!(
45 "Union type '{}' does not contain enum types compatible with variant '{}'",
46 scrutinee_type, variant
47 )));
48 }
49
50 _ => {
51 return Err(self.type_error(format!(
52 "Cannot use enum pattern on non-enum type '{}'",
53 scrutinee_type
54 )))
55 }
56 };
57 let enum_def = {
58 let key = self.resolve_type_key(&type_name);
59 self.env
60 .lookup_enum(&key)
61 .or_else(|| self.env.lookup_enum(&type_name))
62 }
63 .ok_or_else(|| self.type_error(format!("Undefined enum '{}'", type_name)))?
64 .clone();
65 let variant_def = enum_def
66 .variants
67 .iter()
68 .find(|v| &v.name == variant)
69 .ok_or_else(|| {
70 self.type_error(format!(
71 "Enum '{}' has no variant '{}'",
72 type_name, variant
73 ))
74 })?;
75 if let Some(variant_fields) = &variant_def.fields {
76 if bindings.len() != variant_fields.len() {
77 return Err(self.type_error(format!(
78 "Variant '{}::{}' expects {} bindings, got {}",
79 type_name,
80 variant,
81 variant_fields.len(),
82 bindings.len()
83 )));
84 }
85
86 for (binding, field_type) in bindings.iter().zip(variant_fields.iter()) {
87 let bind_type = if let Some(ref types) = variant_types {
88 if let TypeKind::Generic(_) = &field_type.kind {
89 types.get(0).cloned().unwrap_or_else(|| field_type.clone())
90 } else {
91 field_type.clone()
92 }
93 } else {
94 field_type.clone()
95 };
96 self.validate_is_pattern(binding, &bind_type)?;
97 }
98 } else {
99 if !bindings.is_empty() {
100 return Err(self.type_error(format!(
101 "Variant '{}::{}' is a unit variant and takes no bindings",
102 type_name, variant
103 )));
104 }
105 }
106
107 Ok(())
108 }
109
110 Pattern::Struct { .. } => Ok(()),
111 }
112 }
113
114 pub fn extract_type_narrowings_from_expr(&mut self, expr: &Expr) -> Vec<(String, Type)> {
115 let mut narrowings = Vec::new();
116 match &expr.kind {
117 ExprKind::TypeCheck {
118 expr: scrutinee,
119 check_type: target_type,
120 } => {
121 if let ExprKind::Identifier(var_name) = &scrutinee.kind {
122 if let Some(current_type) = self.env.lookup_variable(var_name) {
123 let narrowed_type = if let TypeKind::Named(name) = &target_type.kind {
124 let resolved = self.resolve_type_key(name);
125 if self.env.lookup_trait(&resolved).is_some() {
126 Type::new(TypeKind::Trait(name.clone()), target_type.span)
127 } else {
128 target_type.clone()
129 }
130 } else {
131 target_type.clone()
132 };
133 match ¤t_type.kind {
134 TypeKind::Unknown => {
135 narrowings.push((var_name.clone(), narrowed_type));
136 }
137
138 TypeKind::Union(types) => {
139 for ty in types {
140 if self.types_equal(ty, target_type) {
141 narrowings.push((var_name.clone(), target_type.clone()));
142 break;
143 }
144 }
145 }
146
147 _ => {}
148 }
149 }
150 }
151 }
152
153 ExprKind::IsPattern {
154 expr: scrutinee,
155 pattern,
156 } => {
157 if let Pattern::TypeCheck(target_type) = pattern {
158 if let ExprKind::Identifier(var_name) = &scrutinee.kind {
159 if let Some(current_type) = self.env.lookup_variable(var_name) {
160 match ¤t_type.kind {
161 TypeKind::Unknown => {
162 narrowings.push((var_name.clone(), target_type.clone()));
163 }
164
165 TypeKind::Union(types) => {
166 for ty in types {
167 if self.types_equal(ty, target_type) {
168 narrowings
169 .push((var_name.clone(), target_type.clone()));
170 break;
171 }
172 }
173 }
174
175 _ => {}
176 }
177 }
178 }
179 }
180 }
181
182 ExprKind::Binary { left, op, right } => {
183 if matches!(op, BinaryOp::And) {
184 narrowings.extend(self.extract_type_narrowings_from_expr(left));
185 narrowings.extend(self.extract_type_narrowings_from_expr(right));
186 }
187 }
188
189 _ => {}
190 }
191
192 narrowings
193 }
194
195 pub fn extract_all_pattern_bindings_from_expr<'a>(
196 &self,
197 expr: &'a Expr,
198 ) -> Vec<(&'a Expr, Pattern)> {
199 let mut bindings = Vec::new();
200 match &expr.kind {
201 ExprKind::IsPattern {
202 expr: scrutinee,
203 pattern,
204 } => match pattern {
205 Pattern::Enum {
206 bindings: pattern_bindings,
207 ..
208 } if !pattern_bindings.is_empty() => {
209 bindings.push((scrutinee.as_ref(), pattern.clone()));
210 }
211
212 _ => {}
213 },
214 ExprKind::Binary { left, op, right } => {
215 if matches!(op, BinaryOp::And) {
216 bindings.extend(self.extract_all_pattern_bindings_from_expr(left));
217 bindings.extend(self.extract_all_pattern_bindings_from_expr(right));
218 }
219 }
220
221 _ => {}
222 }
223
224 bindings
225 }
226
227 pub fn bind_pattern(&mut self, pattern: &Pattern, scrutinee_type: &Type) -> Result<()> {
228 match pattern {
229 Pattern::Wildcard => Ok(()),
230 Pattern::Identifier(name) => self
231 .env
232 .declare_variable(name.clone(), scrutinee_type.clone()),
233 Pattern::Literal(_) => Ok(()),
234 Pattern::Struct { name: _, fields: _ } => Ok(()),
235 Pattern::Enum {
236 enum_name: _,
237 variant,
238 bindings,
239 } => {
240 let (type_name, variant_types) = match &scrutinee_type.kind {
241 TypeKind::Named(name) => (name.clone(), None),
242 TypeKind::Option(inner) => {
243 ("Option".to_string(), Some(vec![(**inner).clone()]))
244 }
245
246 TypeKind::Result(ok, err) => (
247 "Result".to_string(),
248 Some(vec![(**ok).clone(), (**err).clone()]),
249 ),
250 _ => {
251 return Err(self
252 .type_error(format!("Expected enum type, got '{}'", scrutinee_type)))
253 }
254 };
255 let enum_def = {
256 let key = self.resolve_type_key(&type_name);
257 self.env
258 .lookup_enum(&key)
259 .or_else(|| self.env.lookup_enum(&type_name))
260 }
261 .ok_or_else(|| self.type_error(format!("Undefined enum '{}'", type_name)))?
262 .clone();
263 let variant_def = enum_def
264 .variants
265 .iter()
266 .find(|v| &v.name == variant)
267 .ok_or_else(|| {
268 self.type_error(format!(
269 "Enum '{}' has no variant '{}'",
270 type_name, variant
271 ))
272 })?;
273 if let Some(variant_fields) = &variant_def.fields {
274 if bindings.len() != variant_fields.len() {
275 return Err(self.type_error(format!(
276 "Variant '{}::{}' expects {} bindings, got {}",
277 type_name,
278 variant,
279 variant_fields.len(),
280 bindings.len()
281 )));
282 }
283
284 for (i, (binding, field_type)) in
285 bindings.iter().zip(variant_fields.iter()).enumerate()
286 {
287 let concrete =
288 variant_types
289 .as_ref()
290 .and_then(|types| match type_name.as_str() {
291 "Option" => {
292 if variant == "Some" {
293 types.get(0).cloned()
294 } else {
295 None
296 }
297 }
298
299 "Result" => match variant.as_str() {
300 "Ok" => types.get(0).cloned(),
301 "Err" => types.get(1).cloned(),
302 _ => types.get(i).cloned(),
303 },
304 _ => types.get(i).cloned(),
305 });
306 let bind_type = if let Some(concrete_type) = concrete {
307 concrete_type
308 } else if matches!(field_type.kind, TypeKind::Generic(_)) {
309 Type::new(TypeKind::Unknown, Self::dummy_span())
310 } else {
311 field_type.clone()
312 };
313 self.bind_pattern(binding, &bind_type)?;
314 }
315 } else {
316 if !bindings.is_empty() {
317 return Err(self.type_error(format!(
318 "Variant '{}::{}' is a unit variant and has no bindings",
319 type_name, variant
320 )));
321 }
322 }
323
324 Ok(())
325 }
326
327 Pattern::TypeCheck(_) => Ok(()),
328 }
329 }
330}