Skip to main content

lisette_semantics/pattern_analysis/
normalize.rs

1use crate::store::Store;
2use syntax::ast::{Literal, MatchArm, TypedPattern};
3use syntax::program::Definition;
4use syntax::types::Type;
5
6use super::NormalizedPattern::Wildcard;
7use super::inhabitance::{InhabitanceCache, is_inhabited, is_variant_inhabited};
8use super::types::Row;
9use super::types::*;
10
11fn make_type_key(name: &str, type_args: &[Type]) -> String {
12    if type_args.is_empty() {
13        name.to_string()
14    } else {
15        let args = type_args
16            .iter()
17            .map(|t| t.to_string())
18            .collect::<Vec<_>>()
19            .join(", ");
20        format!("{}<{}>", name, args)
21    }
22}
23
/// Shared inputs threaded through the pattern-normalization functions in this module.
pub struct NormalizationContext<'a> {
    /// Store queried for definitions (`get_definition`), interfaces (`get_interface`)
    /// and alias peeling (`peel_alias`).
    pub store: &'a Store,
    /// Cache handed to the inhabitance checks (`is_inhabited`, `is_variant_inhabited`,
    /// `is_struct_inhabited`).
    pub cache: &'a InhabitanceCache,
    /// Type of the value being matched, when known. In this module it is read only by
    /// `try_normalize_interface_implementer`, to decide whether a struct pattern should be
    /// normalized as one implementer of an interface type.
    pub scrutinee_type: Option<Type>,
}
29
30fn try_normalize_interface_implementer(
31    ctx: &NormalizationContext,
32    struct_name: &str,
33    arity: usize,
34    args: Vec<NormalizedPattern>,
35    unions: &mut UnionTable,
36) -> Option<NormalizedPattern> {
37    let scrutinee_ty = ctx.scrutinee_type.as_ref()?;
38    let peeled = ctx.store.peel_alias(&scrutinee_ty.resolve());
39    let Type::Constructor {
40        id: interface_id,
41        params: interface_params,
42        ..
43    } = &peeled
44    else {
45        return None;
46    };
47    ctx.store.get_interface(interface_id)?;
48
49    let interface_type_name = make_type_key(interface_id, interface_params);
50    let struct_ctor = Constructor {
51        tag_id: struct_name.to_string(),
52        arity,
53    };
54
55    if let Some(union) = unions.get_mut(&interface_type_name) {
56        let mut found = false;
57        let mut unknown_pos = union.len();
58        for (i, c) in union.iter().enumerate() {
59            if c.tag_id == struct_name {
60                found = true;
61                break;
62            }
63            if c.tag_id == INTERFACE_UNKNOWN_TAG {
64                unknown_pos = i;
65            }
66        }
67        if !found {
68            union.insert(unknown_pos, struct_ctor);
69        }
70    } else {
71        unions.insert(
72            interface_type_name.clone(),
73            vec![
74                struct_ctor,
75                Constructor {
76                    tag_id: INTERFACE_UNKNOWN_TAG.to_string(),
77                    arity: 0,
78                },
79            ],
80        );
81    }
82
83    Some(NormalizedPattern::Constructor {
84        type_name: interface_type_name,
85        tag: struct_name.to_string(),
86        args,
87    })
88}
89
90pub fn normalize_arm(
91    arm: &MatchArm,
92    unions: &mut UnionTable,
93    ctx: &NormalizationContext,
94) -> Vec<Row> {
95    let typed_pattern = arm
96        .typed_pattern
97        .as_ref()
98        .expect("typed pattern should be populated during inference");
99
100    match typed_pattern {
101        TypedPattern::Or { alternatives } => alternatives
102            .iter()
103            .map(|alt| vec![normalize_typed_pattern(alt, unions, ctx)])
104            .collect(),
105        _ => {
106            vec![vec![normalize_typed_pattern(typed_pattern, unions, ctx)]]
107        }
108    }
109}
110
/// Converts a typed pattern into the `NormalizedPattern` form used by the exhaustiveness
/// analysis. As a side effect, it registers in `unions` the constructor set of every named
/// type it encounters.
///
/// Each union is inserted only if its key is not already present. Keys come from
/// `make_type_key`, or from the fixed `"Bool"` / `TupleN` keys used by the helpers.
///
/// # Panics
///
/// Panics (via `unreachable!`) on `TypedPattern::Or`. Top-level or-patterns are expected to be
/// expanded by `normalize_arm`; an or-pattern nested inside another pattern reaches this panic.
pub fn normalize_typed_pattern(
    typed_pattern: &TypedPattern,
    unions: &mut UnionTable,
    ctx: &NormalizationContext,
) -> NormalizedPattern {
    match typed_pattern {
        TypedPattern::Wildcard => Wildcard,

        TypedPattern::Literal(literal) => {
            // Booleans become constructors of a two-alternative `Bool` union (`true`, `false`)
            // so they can be enumerated; every other literal stays a literal pattern.
            if let Literal::Boolean(b) = literal {
                return normalize_boolean(*b, unions);
            }

            NormalizedPattern::Literal(literal.clone())
        }

        TypedPattern::EnumVariant {
            enum_name,
            variant_name,
            fields,
            type_args,
            ..
        } => {
            // Positional sub-patterns are normalized first (this may register other unions).
            let patterns: Vec<NormalizedPattern> = fields
                .iter()
                .map(|f| normalize_typed_pattern(f, unions, ctx))
                .collect();

            let enum_def = ctx.store.get_definition(enum_name);

            // If the name resolves to a struct, try the interface-implementer path, padding
            // missing positional fields with wildcards up to the struct's field count.
            if let Some(Definition::Struct {
                fields: struct_fields,
                ..
            }) = enum_def
            {
                let arity = struct_fields.len();
                let mut args = patterns.clone();
                while args.len() < arity {
                    args.push(Wildcard);
                }
                if let Some(normalized) =
                    try_normalize_interface_implementer(ctx, enum_name, arity, args, unions)
                {
                    return normalized;
                }
            }
            // NOTE(review): when `enum_def` is a `Definition::Struct` but the scrutinee is not
            // an interface, execution continues below. If no union exists yet for `type_name`,
            // an empty one is registered by the `_ => vec![]` arm, and the returned pattern uses
            // the unpadded `patterns` with tag `"{enum_name}.{variant}"`. Confirm this path is
            // unreachable or intended.

            let type_name = make_type_key(enum_name, type_args);

            if unions.get(&type_name).is_none() {
                let alternatives = match enum_def {
                    // Only variants reported as inhabited for these type arguments become
                    // constructors of the union.
                    Some(Definition::Enum {
                        variants, generics, ..
                    }) => variants
                        .iter()
                        .filter(|v| {
                            is_variant_inhabited(v, type_args, generics, ctx.store, ctx.cache)
                        })
                        .map(|v| Constructor {
                            tag_id: format!("{}.{}", enum_name, v.name),
                            arity: v.fields.len(),
                        })
                        .collect(),
                    // Value enums get every named variant plus a synthetic
                    // `__value_enum_unknown__` constructor. Presumably this is so that naming
                    // all variants is not treated as exhaustive (other underlying values may
                    // exist) -- TODO confirm.
                    Some(Definition::ValueEnum { variants, .. }) => {
                        let mut alts: Vec<Constructor> = variants
                            .iter()
                            .map(|v| Constructor {
                                tag_id: format!("{}.{}", enum_name, v.name),
                                arity: 0,
                            })
                            .collect();
                        alts.push(Constructor {
                            tag_id: format!("{}.__value_enum_unknown__", enum_name),
                            arity: 0,
                        });
                        alts
                    }
                    _ => vec![],
                };

                unions.insert(type_name.clone(), alternatives);
            }

            // Keep only the last `.`-separated segment of the variant name so the tag matches
            // the `"{enum_name}.{variant}"` form used for the union's constructors.
            // (`rsplit` always yields at least one item, so the `unwrap_or` is never taken.)
            let variant_name = variant_name.rsplit('.').next().unwrap_or(variant_name);
            let tag = format!("{}.{}", enum_name, variant_name);

            NormalizedPattern::Constructor {
                type_name,
                tag,
                args: patterns,
            }
        }

        TypedPattern::EnumStructVariant {
            enum_name,
            variant_name,
            variant_fields,
            pattern_fields,
            type_args,
        } => {
            // One argument per declared field, in declaration order; fields the pattern does
            // not mention become wildcards.
            let patterns = variant_fields
                .iter()
                .map(|f| {
                    pattern_fields
                        .iter()
                        .find_map(|(name, pattern)| {
                            if *name == f.name {
                                Some(normalize_typed_pattern(pattern, unions, ctx))
                            } else {
                                None
                            }
                        })
                        .unwrap_or(Wildcard)
                })
                .collect();

            let type_name = make_type_key(enum_name, type_args);

            if unions.get(&type_name).is_none() {
                // Same inhabited-variant union as the `EnumVariant` arm; no value-enum case here.
                let alternatives = match ctx.store.get_definition(enum_name) {
                    Some(Definition::Enum {
                        variants, generics, ..
                    }) => variants
                        .iter()
                        .filter(|v| {
                            is_variant_inhabited(v, type_args, generics, ctx.store, ctx.cache)
                        })
                        .map(|v| Constructor {
                            tag_id: format!("{}.{}", enum_name, v.name),
                            arity: v.fields.len(),
                        })
                        .collect(),
                    _ => vec![],
                };

                unions.insert(type_name.clone(), alternatives);
            }

            let variant_name = variant_name.rsplit('.').next().unwrap_or(variant_name);
            let tag = format!("{}.{}", enum_name, variant_name);

            NormalizedPattern::Constructor {
                type_name,
                tag,
                args: patterns,
            }
        }

        TypedPattern::Struct {
            struct_name,
            struct_fields,
            pattern_fields,
            type_args,
        } => {
            // One argument per declared field, in declaration order; fields the pattern does
            // not mention become wildcards.
            let patterns: Vec<NormalizedPattern> = struct_fields
                .iter()
                .map(|f| {
                    pattern_fields
                        .iter()
                        .find_map(|(name, pattern)| {
                            if *name == f.name {
                                Some(normalize_typed_pattern(pattern, unions, ctx))
                            } else {
                                None
                            }
                        })
                        .unwrap_or(Wildcard)
                })
                .collect();

            // Against an interface-typed scrutinee, the struct is one implementer among many.
            if let Some(normalized) = try_normalize_interface_implementer(
                ctx,
                struct_name,
                struct_fields.len(),
                patterns.clone(),
                unions,
            ) {
                return normalized;
            }

            let type_name = make_type_key(struct_name, type_args);

            if unions.get(&type_name).is_none() {
                // A missing or non-struct definition is treated as inhabited.
                let is_inhabited = ctx
                    .store
                    .get_definition(struct_name)
                    .map(|definition| match definition {
                        Definition::Struct {
                            generics, fields, ..
                        } => super::inhabitance::is_struct_inhabited(
                            fields, type_args, generics, ctx.store, ctx.cache,
                        ),
                        _ => true,
                    })
                    .unwrap_or(true);

                // A struct has exactly one constructor, or none if it is uninhabited.
                if is_inhabited {
                    let constructor = Constructor {
                        tag_id: struct_name.to_string(),
                        arity: struct_fields.len(),
                    };
                    unions.insert(type_name.clone(), vec![constructor]);
                } else {
                    unions.insert(type_name.clone(), vec![]);
                }
            }

            NormalizedPattern::Constructor {
                type_name,
                tag: struct_name.to_string(),
                args: patterns,
            }
        }

        TypedPattern::Slice {
            prefix,
            has_rest,
            element_type,
        } => normalize_slice(prefix, *has_rest, element_type, unions, ctx),

        TypedPattern::Tuple { arity, elements } => normalize_tuple(elements, *arity, unions, ctx),

        TypedPattern::Or { .. } => {
            unreachable!("Or-pattern should be handled by normalize_arm")
        }
    }
}
338
339/// Normalize a slice pattern into nested EmptySlice/NonEmptySlice constructors.
340///
341/// Slice is modeled as a 2-variant type:
342/// - EmptySlice: represents []
343/// - NonEmptySlice(head, tail): represents [head, ..tail]
344///
345/// Examples:
346/// - [] → EmptySlice
347/// - [a] → NonEmptySlice(a, EmptySlice)
348/// - [a, b] → NonEmptySlice(a, NonEmptySlice(b, EmptySlice))
349/// - [a, ..rest] → NonEmptySlice(a, Wildcard)
350/// - [..] → Wildcard (matches any slice)
351fn normalize_slice(
352    prefix: &[TypedPattern],
353    has_rest: bool,
354    element_type: &Type,
355    unions: &mut UnionTable,
356    ctx: &NormalizationContext,
357) -> NormalizedPattern {
358    let type_name = make_type_key("Slice", std::slice::from_ref(element_type));
359    if unions.get(&type_name).is_none() {
360        let element_inhabited = is_inhabited(element_type, ctx.store, ctx.cache);
361
362        let mut constructors = vec![Constructor {
363            tag_id: "EmptySlice".to_string(),
364            arity: 0,
365        }];
366
367        if element_inhabited {
368            constructors.push(Constructor {
369                tag_id: "NonEmptySlice".to_string(),
370                arity: 2, // head and tail
371            });
372        }
373
374        unions.insert(type_name.clone(), constructors);
375    }
376
377    if prefix.is_empty() && has_rest {
378        return Wildcard;
379    }
380
381    if prefix.is_empty() && !has_rest {
382        return NormalizedPattern::Constructor {
383            type_name,
384            tag: "EmptySlice".to_string(),
385            args: vec![],
386        };
387    }
388
389    let tail = if has_rest {
390        Wildcard
391    } else {
392        NormalizedPattern::Constructor {
393            type_name: type_name.clone(),
394            tag: "EmptySlice".to_string(),
395            args: vec![],
396        }
397    };
398
399    let mut result = tail;
400    for element in prefix.iter().rev() {
401        let head = normalize_typed_pattern(element, unions, ctx);
402        result = NormalizedPattern::Constructor {
403            type_name: type_name.clone(),
404            tag: "NonEmptySlice".to_string(),
405            args: vec![head, result],
406        };
407    }
408
409    result
410}
411
412fn normalize_tuple(
413    elements: &[TypedPattern],
414    arity: usize,
415    unions: &mut UnionTable,
416    ctx: &NormalizationContext,
417) -> NormalizedPattern {
418    let type_name = format!("Tuple{}", arity);
419
420    if unions.get(&type_name).is_none() {
421        let constructor = Constructor {
422            tag_id: type_name.clone(),
423            arity,
424        };
425        unions.insert(type_name.clone(), vec![constructor]);
426    }
427
428    let patterns = elements
429        .iter()
430        .map(|e| normalize_typed_pattern(e, unions, ctx))
431        .collect();
432
433    NormalizedPattern::Constructor {
434        type_name: type_name.clone(),
435        tag: type_name,
436        args: patterns,
437    }
438}
439
440fn normalize_boolean(boolean: bool, unions: &mut UnionTable) -> NormalizedPattern {
441    let type_name = "Bool".to_string();
442
443    if unions.get(&type_name).is_none() {
444        let make_alt = |b: bool| Constructor {
445            tag_id: b.to_string(),
446            arity: 0,
447        };
448
449        unions.insert(type_name.clone(), vec![make_alt(true), make_alt(false)]);
450    }
451
452    NormalizedPattern::Constructor {
453        type_name,
454        tag: boolean.to_string(),
455        args: vec![],
456    }
457}