// lust/typechecker/expr_checker/patterns.rs
use super::*;
use alloc::{
    format,
    string::{String, ToString},
    vec,
    vec::Vec,
};
8impl TypeChecker {
9 pub fn validate_is_pattern(&mut self, pattern: &Pattern, scrutinee_type: &Type) -> Result<()> {
10 match pattern {
11 Pattern::Wildcard | Pattern::Literal(_) | Pattern::Identifier(_) => Ok(()),
12 Pattern::TypeCheck(check_type) => {
13 let _ = check_type;
14 Ok(())
15 }
16
17 Pattern::Enum {
18 enum_name: _,
19 variant,
20 bindings,
21 } => {
22 let (type_name, variant_types) = match &scrutinee_type.kind {
23 TypeKind::Named(name) => (name.clone(), None),
24 TypeKind::Option(inner) => {
25 ("Option".to_string(), Some(vec![(**inner).clone()]))
26 }
27
28 TypeKind::Result(ok, err) => (
29 "Result".to_string(),
30 Some(vec![(**ok).clone(), (**err).clone()]),
31 ),
32 TypeKind::Union(types) => {
33 for ty in types.iter() {
34 if let TypeKind::Named(name) = &ty.kind {
35 if let Some(_) = {
36 let key = self.resolve_type_key(name);
37 self.env
38 .lookup_enum(&key)
39 .or_else(|| self.env.lookup_enum(name))
40 } {
41 return Ok(());
42 }
43 }
44
45 if matches!(ty.kind, TypeKind::Option(_) | TypeKind::Result(_, _)) {
46 return Ok(());
47 }
48 }
49
50 return Err(self.type_error(format!(
51 "Union type '{}' does not contain enum types compatible with variant '{}'",
52 scrutinee_type, variant
53 )));
54 }
55
56 _ => {
57 return Err(self.type_error(format!(
58 "Cannot use enum pattern on non-enum type '{}'",
59 scrutinee_type
60 )))
61 }
62 };
63 let enum_def = {
64 let key = self.resolve_type_key(&type_name);
65 self.env
66 .lookup_enum(&key)
67 .or_else(|| self.env.lookup_enum(&type_name))
68 }
69 .ok_or_else(|| self.type_error(format!("Undefined enum '{}'", type_name)))?
70 .clone();
71 let variant_def = enum_def
72 .variants
73 .iter()
74 .find(|v| &v.name == variant)
75 .ok_or_else(|| {
76 self.type_error(format!(
77 "Enum '{}' has no variant '{}'",
78 type_name, variant
79 ))
80 })?;
81 if let Some(variant_fields) = &variant_def.fields {
82 if bindings.len() != variant_fields.len() {
83 return Err(self.type_error(format!(
84 "Variant '{}::{}' expects {} bindings, got {}",
85 type_name,
86 variant,
87 variant_fields.len(),
88 bindings.len()
89 )));
90 }
91
92 for (binding, field_type) in bindings.iter().zip(variant_fields.iter()) {
93 let bind_type = if let Some(ref types) = variant_types {
94 if let TypeKind::Generic(_) = &field_type.kind {
95 types.get(0).cloned().unwrap_or_else(|| field_type.clone())
96 } else {
97 field_type.clone()
98 }
99 } else {
100 field_type.clone()
101 };
102 self.validate_is_pattern(binding, &bind_type)?;
103 }
104 } else {
105 if !bindings.is_empty() {
106 return Err(self.type_error(format!(
107 "Variant '{}::{}' is a unit variant and takes no bindings",
108 type_name, variant
109 )));
110 }
111 }
112
113 Ok(())
114 }
115
116 Pattern::Struct { .. } => Ok(()),
117 }
118 }
119
120 pub fn extract_type_narrowings_from_expr(&mut self, expr: &Expr) -> Vec<(String, Type)> {
121 let mut narrowings = Vec::new();
122 match &expr.kind {
123 ExprKind::TypeCheck {
124 expr: scrutinee,
125 check_type: target_type,
126 } => {
127 if let ExprKind::Identifier(var_name) = &scrutinee.kind {
128 if let Some(current_type) = self.env.lookup_variable(var_name) {
129 let narrowed_type = if let TypeKind::Named(name) = &target_type.kind {
130 let resolved = self.resolve_type_key(name);
131 if self.env.lookup_trait(&resolved).is_some() {
132 Type::new(TypeKind::Trait(name.clone()), target_type.span)
133 } else {
134 target_type.clone()
135 }
136 } else {
137 target_type.clone()
138 };
139 match ¤t_type.kind {
140 TypeKind::Unknown => {
141 narrowings.push((var_name.clone(), narrowed_type));
142 }
143
144 TypeKind::Union(types) => {
145 for ty in types {
146 if self.types_equal(ty, target_type) {
147 narrowings.push((var_name.clone(), target_type.clone()));
148 break;
149 }
150 }
151 }
152
153 _ => {}
154 }
155 }
156 }
157 }
158
159 ExprKind::IsPattern {
160 expr: scrutinee,
161 pattern,
162 } => {
163 if let Pattern::TypeCheck(target_type) = pattern {
164 if let ExprKind::Identifier(var_name) = &scrutinee.kind {
165 if let Some(current_type) = self.env.lookup_variable(var_name) {
166 match ¤t_type.kind {
167 TypeKind::Unknown => {
168 narrowings.push((var_name.clone(), target_type.clone()));
169 }
170
171 TypeKind::Union(types) => {
172 for ty in types {
173 if self.types_equal(ty, target_type) {
174 narrowings
175 .push((var_name.clone(), target_type.clone()));
176 break;
177 }
178 }
179 }
180
181 _ => {}
182 }
183 }
184 }
185 }
186 }
187
188 ExprKind::Binary { left, op, right } => {
189 if matches!(op, BinaryOp::And) {
190 narrowings.extend(self.extract_type_narrowings_from_expr(left));
191 narrowings.extend(self.extract_type_narrowings_from_expr(right));
192 }
193 }
194
195 _ => {}
196 }
197
198 narrowings
199 }
200
201 pub fn extract_all_pattern_bindings_from_expr<'a>(
202 &self,
203 expr: &'a Expr,
204 ) -> Vec<(&'a Expr, Pattern)> {
205 let mut bindings = Vec::new();
206 match &expr.kind {
207 ExprKind::IsPattern {
208 expr: scrutinee,
209 pattern,
210 } => match pattern {
211 Pattern::Enum {
212 bindings: pattern_bindings,
213 ..
214 } if !pattern_bindings.is_empty() => {
215 bindings.push((scrutinee.as_ref(), pattern.clone()));
216 }
217
218 _ => {}
219 },
220 ExprKind::Binary { left, op, right } => {
221 if matches!(op, BinaryOp::And) {
222 bindings.extend(self.extract_all_pattern_bindings_from_expr(left));
223 bindings.extend(self.extract_all_pattern_bindings_from_expr(right));
224 }
225 }
226
227 _ => {}
228 }
229
230 bindings
231 }
232
233 pub fn bind_pattern(&mut self, pattern: &Pattern, scrutinee_type: &Type) -> Result<()> {
234 match pattern {
235 Pattern::Wildcard => Ok(()),
236 Pattern::Identifier(name) => self
237 .env
238 .declare_variable(name.clone(), scrutinee_type.clone()),
239 Pattern::Literal(_) => Ok(()),
240 Pattern::Struct { name: _, fields: _ } => Ok(()),
241 Pattern::Enum {
242 enum_name: _,
243 variant,
244 bindings,
245 } => {
246 let (type_name, variant_types) = match &scrutinee_type.kind {
247 TypeKind::Named(name) => (name.clone(), None),
248 TypeKind::Option(inner) => {
249 ("Option".to_string(), Some(vec![(**inner).clone()]))
250 }
251
252 TypeKind::Result(ok, err) => (
253 "Result".to_string(),
254 Some(vec![(**ok).clone(), (**err).clone()]),
255 ),
256 _ => {
257 return Err(self
258 .type_error(format!("Expected enum type, got '{}'", scrutinee_type)))
259 }
260 };
261 let enum_def = {
262 let key = self.resolve_type_key(&type_name);
263 self.env
264 .lookup_enum(&key)
265 .or_else(|| self.env.lookup_enum(&type_name))
266 }
267 .ok_or_else(|| self.type_error(format!("Undefined enum '{}'", type_name)))?
268 .clone();
269 let variant_def = enum_def
270 .variants
271 .iter()
272 .find(|v| &v.name == variant)
273 .ok_or_else(|| {
274 self.type_error(format!(
275 "Enum '{}' has no variant '{}'",
276 type_name, variant
277 ))
278 })?;
279 if let Some(variant_fields) = &variant_def.fields {
280 if bindings.len() != variant_fields.len() {
281 return Err(self.type_error(format!(
282 "Variant '{}::{}' expects {} bindings, got {}",
283 type_name,
284 variant,
285 variant_fields.len(),
286 bindings.len()
287 )));
288 }
289
290 for (i, (binding, field_type)) in
291 bindings.iter().zip(variant_fields.iter()).enumerate()
292 {
293 let concrete =
294 variant_types
295 .as_ref()
296 .and_then(|types| match type_name.as_str() {
297 "Option" => {
298 if variant == "Some" {
299 types.get(0).cloned()
300 } else {
301 None
302 }
303 }
304
305 "Result" => match variant.as_str() {
306 "Ok" => types.get(0).cloned(),
307 "Err" => types.get(1).cloned(),
308 _ => types.get(i).cloned(),
309 },
310 _ => types.get(i).cloned(),
311 });
312 let bind_type = if let Some(concrete_type) = concrete {
313 concrete_type
314 } else if matches!(field_type.kind, TypeKind::Generic(_)) {
315 Type::new(TypeKind::Unknown, Self::dummy_span())
316 } else {
317 field_type.clone()
318 };
319 self.bind_pattern(binding, &bind_type)?;
320 }
321 } else {
322 if !bindings.is_empty() {
323 return Err(self.type_error(format!(
324 "Variant '{}::{}' is a unit variant and has no bindings",
325 type_name, variant
326 )));
327 }
328 }
329
330 Ok(())
331 }
332
333 Pattern::TypeCheck(_) => Ok(()),
334 }
335 }
336}