lust/typechecker/expr_checker/
patterns.rs1use super::*;
2use alloc::{
3 format,
4 string::{String, ToString},
5 vec,
6 vec::Vec,
7};
8impl TypeChecker {
9 pub fn validate_is_pattern(&mut self, pattern: &Pattern, scrutinee_type: &Type) -> Result<()> {
10 match pattern {
11 Pattern::Wildcard | Pattern::Literal(_) | Pattern::Identifier(_) => Ok(()),
12 Pattern::TypeCheck(check_type) => {
13 let _ = check_type;
14 Ok(())
15 }
16
17 Pattern::Enum {
18 enum_name: _,
19 variant,
20 bindings,
21 } => {
22 let (type_name, variant_types) = match &scrutinee_type.kind {
23 TypeKind::Named(name) => (name.clone(), None),
24 TypeKind::Option(inner) => {
25 ("Option".to_string(), Some(vec![(**inner).clone()]))
26 }
27
28 TypeKind::Result(ok, err) => (
29 "Result".to_string(),
30 Some(vec![(**ok).clone(), (**err).clone()]),
31 ),
32 TypeKind::Union(types) => {
33 for ty in types.iter() {
34 if let TypeKind::Named(name) = &ty.kind {
35 if let Some(_) = {
36 let key = self.resolve_type_key(name);
37 self.env
38 .lookup_enum(&key)
39 .or_else(|| self.env.lookup_enum(name))
40 } {
41 return Ok(());
42 }
43 }
44
45 if matches!(ty.kind, TypeKind::Option(_) | TypeKind::Result(_, _)) {
46 return Ok(());
47 }
48 }
49
50 return Err(self.type_error(format!(
51 "Union type '{}' does not contain enum types compatible with variant '{}'",
52 scrutinee_type, variant
53 )));
54 }
55
56 _ => {
57 return Err(self.type_error(format!(
58 "Cannot use enum pattern on non-enum type '{}'",
59 scrutinee_type
60 )))
61 }
62 };
63 let enum_def = {
64 let key = self.resolve_type_key(&type_name);
65 self.env
66 .lookup_enum(&key)
67 .or_else(|| self.env.lookup_enum(&type_name))
68 }
69 .ok_or_else(|| self.type_error(format!("Undefined enum '{}'", type_name)))?
70 .clone();
71 let variant_def = enum_def
72 .variants
73 .iter()
74 .find(|v| &v.name == variant)
75 .ok_or_else(|| {
76 self.type_error(format!(
77 "Enum '{}' has no variant '{}'",
78 type_name, variant
79 ))
80 })?;
81 if let Some(variant_fields) = &variant_def.fields {
82 if bindings.len() != variant_fields.len() {
83 return Err(self.type_error(format!(
84 "Variant '{}::{}' expects {} bindings, got {}",
85 type_name,
86 variant,
87 variant_fields.len(),
88 bindings.len()
89 )));
90 }
91
92 for (binding, field_type) in bindings.iter().zip(variant_fields.iter()) {
93 let bind_type = if let Some(ref types) = variant_types {
94 if let TypeKind::Generic(_) = &field_type.kind {
95 types.get(0).cloned().unwrap_or_else(|| field_type.clone())
96 } else {
97 field_type.clone()
98 }
99 } else {
100 field_type.clone()
101 };
102 self.validate_is_pattern(binding, &bind_type)?;
103 }
104 } else {
105 if !bindings.is_empty() {
106 return Err(self.type_error(format!(
107 "Variant '{}::{}' is a unit variant and takes no bindings",
108 type_name, variant
109 )));
110 }
111 }
112
113 Ok(())
114 }
115
116 Pattern::Struct { .. } => Ok(()),
117 }
118 }
119
120 pub fn extract_type_narrowings_from_expr(&mut self, expr: &Expr) -> Vec<(String, Type)> {
121 let mut narrowings = Vec::new();
122 match &expr.kind {
123 ExprKind::TypeCheck {
124 expr: scrutinee,
125 check_type: target_type,
126 } => {
127 if let ExprKind::Identifier(var_name) = &scrutinee.kind {
128 if let Some(current_type) = self.env.lookup_variable(var_name) {
129 let narrowed_type = if let TypeKind::Named(name) = &target_type.kind {
130 let resolved = self.resolve_type_key(name);
131 if self.env.lookup_trait(&resolved).is_some() {
132 Type::new(TypeKind::Trait(name.clone()), target_type.span)
133 } else {
134 target_type.clone()
135 }
136 } else {
137 target_type.clone()
138 };
139 match ¤t_type.kind {
140 TypeKind::Unknown => {
141 narrowings.push((var_name.clone(), narrowed_type));
142 }
143
144 TypeKind::Union(types) => {
145 for ty in types {
146 if self.types_equal(ty, target_type) {
147 narrowings.push((var_name.clone(), target_type.clone()));
148 break;
149 }
150 }
151 }
152
153 _ => {}
154 }
155 }
156 }
157 }
158
159 ExprKind::IsPattern {
160 expr: scrutinee,
161 pattern,
162 } => {
163 if let Pattern::TypeCheck(target_type) = pattern {
164 if let ExprKind::Identifier(var_name) = &scrutinee.kind {
165 if let Some(current_type) = self.env.lookup_variable(var_name) {
166 match ¤t_type.kind {
167 TypeKind::Unknown => {
168 narrowings.push((var_name.clone(), target_type.clone()));
169 }
170
171 TypeKind::Union(types) => {
172 for ty in types {
173 if self.types_equal(ty, target_type) {
174 narrowings
175 .push((var_name.clone(), target_type.clone()));
176 break;
177 }
178 }
179 }
180
181 _ => {}
182 }
183 }
184 }
185 }
186 }
187
188 ExprKind::Binary { left, op, right } => {
189 if matches!(op, BinaryOp::And) {
190 narrowings.extend(self.extract_type_narrowings_from_expr(left));
191 narrowings.extend(self.extract_type_narrowings_from_expr(right));
192 }
193 }
194
195 ExprKind::Paren(inner) => {
196 narrowings.extend(self.extract_type_narrowings_from_expr(inner));
197 }
198
199 _ => {}
200 }
201
202 narrowings
203 }
204
205 pub fn extract_all_pattern_bindings_from_expr<'a>(
206 &self,
207 expr: &'a Expr,
208 ) -> Vec<(&'a Expr, Pattern)> {
209 let mut bindings = Vec::new();
210 match &expr.kind {
211 ExprKind::IsPattern {
212 expr: scrutinee,
213 pattern,
214 } => match pattern {
215 Pattern::Enum {
216 bindings: pattern_bindings,
217 ..
218 } if !pattern_bindings.is_empty() => {
219 bindings.push((scrutinee.as_ref(), pattern.clone()));
220 }
221
222 _ => {}
223 },
224 ExprKind::Binary { left, op, right } => {
225 if matches!(op, BinaryOp::And) {
226 bindings.extend(self.extract_all_pattern_bindings_from_expr(left));
227 bindings.extend(self.extract_all_pattern_bindings_from_expr(right));
228 }
229 }
230
231 ExprKind::Paren(inner) => {
232 bindings.extend(self.extract_all_pattern_bindings_from_expr(inner));
233 }
234
235 _ => {}
236 }
237
238 bindings
239 }
240
241 pub fn bind_pattern(&mut self, pattern: &Pattern, scrutinee_type: &Type) -> Result<()> {
242 match pattern {
243 Pattern::Wildcard => Ok(()),
244 Pattern::Identifier(name) => self
245 .env
246 .declare_variable(name.clone(), scrutinee_type.clone()),
247 Pattern::Literal(_) => Ok(()),
248 Pattern::Struct { name: _, fields: _ } => Ok(()),
249 Pattern::Enum {
250 enum_name: _,
251 variant,
252 bindings,
253 } => {
254 let (type_name, variant_types) = match &scrutinee_type.kind {
255 TypeKind::Named(name) => (name.clone(), None),
256 TypeKind::Option(inner) => {
257 ("Option".to_string(), Some(vec![(**inner).clone()]))
258 }
259
260 TypeKind::Result(ok, err) => (
261 "Result".to_string(),
262 Some(vec![(**ok).clone(), (**err).clone()]),
263 ),
264 _ => {
265 return Err(self
266 .type_error(format!("Expected enum type, got '{}'", scrutinee_type)))
267 }
268 };
269 let enum_def = {
270 let key = self.resolve_type_key(&type_name);
271 self.env
272 .lookup_enum(&key)
273 .or_else(|| self.env.lookup_enum(&type_name))
274 }
275 .ok_or_else(|| self.type_error(format!("Undefined enum '{}'", type_name)))?
276 .clone();
277 let variant_def = enum_def
278 .variants
279 .iter()
280 .find(|v| &v.name == variant)
281 .ok_or_else(|| {
282 self.type_error(format!(
283 "Enum '{}' has no variant '{}'",
284 type_name, variant
285 ))
286 })?;
287 if let Some(variant_fields) = &variant_def.fields {
288 if bindings.len() != variant_fields.len() {
289 return Err(self.type_error(format!(
290 "Variant '{}::{}' expects {} bindings, got {}",
291 type_name,
292 variant,
293 variant_fields.len(),
294 bindings.len()
295 )));
296 }
297
298 for (i, (binding, field_type)) in
299 bindings.iter().zip(variant_fields.iter()).enumerate()
300 {
301 let concrete =
302 variant_types
303 .as_ref()
304 .and_then(|types| match type_name.as_str() {
305 "Option" => {
306 if variant == "Some" {
307 types.get(0).cloned()
308 } else {
309 None
310 }
311 }
312
313 "Result" => match variant.as_str() {
314 "Ok" => types.get(0).cloned(),
315 "Err" => types.get(1).cloned(),
316 _ => types.get(i).cloned(),
317 },
318 _ => types.get(i).cloned(),
319 });
320 let bind_type = if let Some(concrete_type) = concrete {
321 concrete_type
322 } else if matches!(field_type.kind, TypeKind::Generic(_)) {
323 Type::new(TypeKind::Unknown, Self::dummy_span())
324 } else {
325 field_type.clone()
326 };
327 self.bind_pattern(binding, &bind_type)?;
328 }
329 } else {
330 if !bindings.is_empty() {
331 return Err(self.type_error(format!(
332 "Variant '{}::{}' is a unit variant and has no bindings",
333 type_name, variant
334 )));
335 }
336 }
337
338 Ok(())
339 }
340
341 Pattern::TypeCheck(_) => Ok(()),
342 }
343 }
344}