Skip to main content

aver/ir/
matches.rs

1use crate::ast::{BinOp, Expr, Literal, MatchArm, Pattern};
2
3use super::{CallLowerCtx, SemanticConstructor, WrapperKind, classify_constructor_name};
4
/// A literal key usable in a dispatch table, lowered from an AST [`Literal`].
///
/// `Float` keeps the literal's string rendering rather than the numeric value
/// (presumably so this enum can derive `Eq`; confirm before relying on it).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DispatchLiteral {
    /// Integer literal key.
    Int(i64),
    /// Float literal key, stored as produced by `to_string()` on the AST value.
    Float(String),
    /// Boolean literal key.
    Bool(bool),
    /// String literal key.
    Str(String),
    /// The unit literal.
    Unit,
}
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub enum SemanticDispatchPattern {
16    Literal(DispatchLiteral),
17    EmptyList,
18    NoneValue,
19    WrapperTag(WrapperKind),
20}
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq)]
23pub struct BoolMatchShape {
24    pub true_arm_index: usize,
25    pub false_arm_index: usize,
26}
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq)]
29pub enum BoolCompareOp {
30    Eq,
31    Lt,
32    Gt,
33}
34
35#[derive(Debug, Clone, Copy, PartialEq)]
36pub enum BoolSubjectPlan<'a> {
37    Expr(&'a Expr),
38    Compare {
39        lhs: &'a Expr,
40        rhs: &'a Expr,
41        op: BoolCompareOp,
42        invert: bool,
43    },
44}
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq)]
47pub struct ListMatchShape {
48    pub empty_arm_index: usize,
49    pub cons_arm_index: usize,
50}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub enum DispatchBindingPlan {
54    None,
55    WrapperPayload(String),
56}
57
58#[derive(Debug, Clone, PartialEq, Eq)]
59pub struct DispatchArmPlan {
60    pub pattern: SemanticDispatchPattern,
61    pub arm_index: usize,
62    pub binding: DispatchBindingPlan,
63}
64
65#[derive(Debug, Clone, PartialEq, Eq)]
66pub struct DispatchDefaultPlan {
67    pub arm_index: usize,
68    pub binding_name: Option<String>,
69}
70
71#[derive(Debug, Clone, PartialEq, Eq)]
72pub struct DispatchTableShape {
73    pub entries: Vec<DispatchArmPlan>,
74    pub default_arm: Option<DispatchDefaultPlan>,
75}
76
77#[derive(Debug, Clone, PartialEq, Eq)]
78pub enum MatchDispatchPlan {
79    Bool(BoolMatchShape),
80    List(ListMatchShape),
81    Table(DispatchTableShape),
82}
83
84pub fn classify_bool_match_shape_from_patterns(patterns: &[&Pattern]) -> Option<BoolMatchShape> {
85    if patterns.len() != 2 {
86        return None;
87    }
88
89    match (patterns[0], patterns[1]) {
90        (Pattern::Literal(Literal::Bool(true)), Pattern::Literal(Literal::Bool(false))) => {
91            Some(BoolMatchShape {
92                true_arm_index: 0,
93                false_arm_index: 1,
94            })
95        }
96        (Pattern::Literal(Literal::Bool(false)), Pattern::Literal(Literal::Bool(true))) => {
97            Some(BoolMatchShape {
98                true_arm_index: 1,
99                false_arm_index: 0,
100            })
101        }
102        (Pattern::Literal(Literal::Bool(true)), Pattern::Wildcard | Pattern::Ident(_)) => {
103            Some(BoolMatchShape {
104                true_arm_index: 0,
105                false_arm_index: 1,
106            })
107        }
108        _ => None,
109    }
110}
111
112pub fn classify_list_match_shape_from_patterns(patterns: &[&Pattern]) -> Option<ListMatchShape> {
113    if patterns.len() != 2 {
114        return None;
115    }
116
117    match (patterns[0], patterns[1]) {
118        (Pattern::EmptyList, Pattern::Cons(_, _)) => Some(ListMatchShape {
119            empty_arm_index: 0,
120            cons_arm_index: 1,
121        }),
122        (Pattern::Cons(_, _), Pattern::EmptyList) => Some(ListMatchShape {
123            empty_arm_index: 1,
124            cons_arm_index: 0,
125        }),
126        _ => None,
127    }
128}
129
130pub fn classify_dispatch_table_shape_from_patterns(
131    patterns: &[&Pattern],
132    ctx: &impl CallLowerCtx,
133) -> Option<DispatchTableShape> {
134    if patterns.len() < 2 {
135        return None;
136    }
137
138    let has_default = matches!(patterns.last(), Some(Pattern::Wildcard | Pattern::Ident(_)));
139    let dispatchable_end = if has_default {
140        patterns.len() - 1
141    } else {
142        patterns.len()
143    };
144
145    let mut entries = Vec::new();
146    for (arm_index, pattern) in patterns[..dispatchable_end].iter().enumerate() {
147        let semantic = classify_dispatch_pattern(pattern, ctx)?;
148        entries.push(DispatchArmPlan {
149            binding: classify_dispatch_binding(pattern, &semantic),
150            pattern: semantic,
151            arm_index,
152        });
153    }
154
155    if entries.len() < 2 {
156        return None;
157    }
158
159    Some(DispatchTableShape {
160        entries,
161        default_arm: has_default
162            .then(|| classify_default_arm_plan(patterns[patterns.len() - 1], patterns.len() - 1)),
163    })
164}
165
166pub fn classify_match_dispatch_plan_from_patterns(
167    patterns: &[&Pattern],
168    ctx: &impl CallLowerCtx,
169) -> Option<MatchDispatchPlan> {
170    if let Some(shape) = classify_bool_match_shape_from_patterns(patterns) {
171        return Some(MatchDispatchPlan::Bool(shape));
172    }
173
174    if let Some(shape) = classify_list_match_shape_from_patterns(patterns) {
175        return Some(MatchDispatchPlan::List(shape));
176    }
177
178    classify_dispatch_table_shape_from_patterns(patterns, ctx).map(MatchDispatchPlan::Table)
179}
180
181pub fn classify_dispatch_pattern(
182    pattern: &Pattern,
183    ctx: &impl CallLowerCtx,
184) -> Option<SemanticDispatchPattern> {
185    match pattern {
186        Pattern::Literal(lit) => Some(SemanticDispatchPattern::Literal(dispatch_literal_from_ast(
187            lit,
188        ))),
189        Pattern::EmptyList => Some(SemanticDispatchPattern::EmptyList),
190        Pattern::Constructor(name, bindings) => match classify_constructor_name(name, ctx) {
191            SemanticConstructor::NoneValue if bindings.is_empty() => {
192                Some(SemanticDispatchPattern::NoneValue)
193            }
194            SemanticConstructor::Wrapper(kind) if bindings.len() <= 1 => {
195                Some(SemanticDispatchPattern::WrapperTag(kind))
196            }
197            _ => None,
198        },
199        _ => None,
200    }
201}
202
203pub fn classify_bool_match_shape(arms: &[MatchArm]) -> Option<BoolMatchShape> {
204    let patterns: Vec<&Pattern> = arms.iter().map(|arm| &arm.pattern).collect();
205    classify_bool_match_shape_from_patterns(&patterns)
206}
207
208pub fn classify_bool_subject_plan(subject: &Expr) -> BoolSubjectPlan<'_> {
209    let Expr::BinOp(op, lhs, rhs) = subject else {
210        return BoolSubjectPlan::Expr(subject);
211    };
212
213    match op {
214        BinOp::Eq => BoolSubjectPlan::Compare {
215            lhs,
216            rhs,
217            op: BoolCompareOp::Eq,
218            invert: false,
219        },
220        BinOp::Lt => BoolSubjectPlan::Compare {
221            lhs,
222            rhs,
223            op: BoolCompareOp::Lt,
224            invert: false,
225        },
226        BinOp::Gt => BoolSubjectPlan::Compare {
227            lhs,
228            rhs,
229            op: BoolCompareOp::Gt,
230            invert: false,
231        },
232        BinOp::Neq => BoolSubjectPlan::Compare {
233            lhs,
234            rhs,
235            op: BoolCompareOp::Eq,
236            invert: true,
237        },
238        BinOp::Gte => BoolSubjectPlan::Compare {
239            lhs,
240            rhs,
241            op: BoolCompareOp::Lt,
242            invert: true,
243        },
244        BinOp::Lte => BoolSubjectPlan::Compare {
245            lhs,
246            rhs,
247            op: BoolCompareOp::Gt,
248            invert: true,
249        },
250        BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => BoolSubjectPlan::Expr(subject),
251    }
252}
253
254pub fn classify_list_match_shape(arms: &[MatchArm]) -> Option<ListMatchShape> {
255    let patterns: Vec<&Pattern> = arms.iter().map(|arm| &arm.pattern).collect();
256    classify_list_match_shape_from_patterns(&patterns)
257}
258
259pub fn classify_dispatch_table_shape(
260    arms: &[MatchArm],
261    ctx: &impl CallLowerCtx,
262) -> Option<DispatchTableShape> {
263    let patterns: Vec<&Pattern> = arms.iter().map(|arm| &arm.pattern).collect();
264    classify_dispatch_table_shape_from_patterns(&patterns, ctx)
265}
266
267pub fn classify_match_dispatch_plan(
268    arms: &[MatchArm],
269    ctx: &impl CallLowerCtx,
270) -> Option<MatchDispatchPlan> {
271    let patterns: Vec<&Pattern> = arms.iter().map(|arm| &arm.pattern).collect();
272    classify_match_dispatch_plan_from_patterns(&patterns, ctx)
273}
274
275fn dispatch_literal_from_ast(lit: &Literal) -> DispatchLiteral {
276    match lit {
277        Literal::Int(i) => DispatchLiteral::Int(*i),
278        Literal::Float(f) => DispatchLiteral::Float(f.to_string()),
279        Literal::Bool(b) => DispatchLiteral::Bool(*b),
280        Literal::Str(s) => DispatchLiteral::Str(s.clone()),
281        Literal::Unit => DispatchLiteral::Unit,
282    }
283}
284
285fn classify_dispatch_binding(
286    pattern: &Pattern,
287    semantic: &SemanticDispatchPattern,
288) -> DispatchBindingPlan {
289    match (pattern, semantic) {
290        (Pattern::Constructor(_, bindings), SemanticDispatchPattern::WrapperTag(_))
291            if !bindings.is_empty() && bindings[0] != "_" =>
292        {
293            DispatchBindingPlan::WrapperPayload(bindings[0].clone())
294        }
295        _ => DispatchBindingPlan::None,
296    }
297}
298
299fn classify_default_arm_plan(pattern: &Pattern, arm_index: usize) -> DispatchDefaultPlan {
300    let binding_name = match pattern {
301        Pattern::Ident(name) if name != "_" => Some(name.clone()),
302        _ => None,
303    };
304
305    DispatchDefaultPlan {
306        arm_index,
307        binding_name,
308    }
309}