Skip to main content

tidepool_optimize/
case_reduce.rs

1use tidepool_eval::{Changed, Pass};
2use tidepool_repr::{get_children, replace_subtree, AltCon, CoreExpr, CoreFrame};
3
/// A pass that performs case-of-known-constructor and case-of-known-literal reductions.
///
/// When a `case` scrutinises a node that is already a constructor application
/// (`Con`) or a literal (`Lit`), the matching alternative (or the `Default`
/// alternative) is selected statically and the `case` node is replaced by that
/// alternative's body. Each call to `run` performs at most one such reduction.
#[derive(Debug, Clone, Copy, Default)]
pub struct CaseReduce;
6
7impl Pass for CaseReduce {
8    fn run(&self, expr: &mut CoreExpr) -> Changed {
9        if expr.nodes.is_empty() {
10            return false;
11        }
12        match try_case_reduce(expr) {
13            Some(new_expr) => {
14                *expr = new_expr;
15                true
16            }
17            None => false,
18        }
19    }
20
21    fn name(&self) -> &str {
22        "CaseReduce"
23    }
24}
25
26fn try_case_reduce(expr: &CoreExpr) -> Option<CoreExpr> {
27    try_case_reduce_at(expr, expr.nodes.len() - 1)
28}
29
30fn try_case_reduce_at(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
31    match &expr.nodes[idx] {
32        CoreFrame::Case {
33            scrutinee,
34            binder,
35            alts,
36        } => {
37            match &expr.nodes[*scrutinee] {
38                CoreFrame::Con { tag, fields } => {
39                    // Find matching DataAlt or Default
40                    let alt = alts
41                        .iter()
42                        .find(|a| matches!(&a.con, AltCon::DataAlt(t) if t == tag))
43                        .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
44
45                    if let Some(alt) = alt {
46                        // Arity check for DataAlt: binders must match fields.
47                        // If mismatch, skip this reduction (malformed IR).
48                        if let AltCon::DataAlt(_) = &alt.con {
49                            if alt.binders.len() != fields.len() {
50                                return try_children(expr, idx);
51                            }
52                        }
53
54                        let mut body = expr.extract_subtree(alt.body);
55                        // Bind fields to alt binders
56                        if let AltCon::DataAlt(_) = &alt.con {
57                            for (alt_binder, field_idx) in alt.binders.iter().zip(fields.iter()) {
58                                let field_tree = expr.extract_subtree(*field_idx);
59                                body = tidepool_repr::subst::subst(&body, *alt_binder, &field_tree);
60                            }
61                        }
62                        // Substitute case binder with scrutinee
63                        let scrut_tree = expr.extract_subtree(*scrutinee);
64                        body = tidepool_repr::subst::subst(&body, *binder, &scrut_tree);
65                        Some(replace_subtree(expr, idx, &body))
66                    } else {
67                        // No matching alt — try children
68                        try_children(expr, idx)
69                    }
70                }
71                CoreFrame::Lit(lit) => {
72                    let alt = alts
73                        .iter()
74                        .find(|a| matches!(&a.con, AltCon::LitAlt(l) if l == lit))
75                        .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
76
77                    if let Some(alt) = alt {
78                        let mut body = expr.extract_subtree(alt.body);
79                        // Substitute case binder with scrutinee literal
80                        let scrut_tree = expr.extract_subtree(*scrutinee);
81                        body = tidepool_repr::subst::subst(&body, *binder, &scrut_tree);
82                        Some(replace_subtree(expr, idx, &body))
83                    } else {
84                        try_children(expr, idx)
85                    }
86                }
87                _ => try_children(expr, idx),
88            }
89        }
90        _ => try_children(expr, idx),
91    }
92}
93
94fn try_children(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
95    let children = get_children(&expr.nodes[idx]);
96    for child in children {
97        if let Some(result) = try_case_reduce_at(expr, child) {
98            return Some(result);
99        }
100    }
101    None
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::env::Env;
    use tidepool_eval::heap::VecHeap;
    use tidepool_eval::value::Value;
    use tidepool_repr::{Alt, DataConId, Literal, PrimOpKind, VarId};

    #[test]
    fn test_case_known_con() {
        // case Con(tag=1, [42]) of w { DataAlt(1) [y] -> y }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            }, // 1
            CoreFrame::Var(VarId(3)), // 2: y
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(2), // w
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(3)],
                    body: 2,
                }],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        // Result should be Lit(42)
        assert_eq!(expr.nodes.len(), 1);
        assert!(matches!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42))));
    }

    #[test]
    fn test_case_known_con_pair() {
        // case Con(tag=1, [1, 2]) of w { DataAlt(1) [a, b] -> PrimOp(IntAdd, [a, b]) }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Lit(Literal::LitInt(2)), // 1
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0, 1],
            }, // 2
            CoreFrame::Var(VarId(10)), // 3: a
            CoreFrame::Var(VarId(11)), // 4: b
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![3, 4],
            }, // 5
            CoreFrame::Case {
                scrutinee: 2,
                binder: VarId(12),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(10), VarId(11)],
                    body: 5,
                }],
            }, // 6
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;

        let mut heap = VecHeap::new();
        let val_before = tidepool_eval::eval(&expr, &Env::new(), &mut heap).unwrap();

        let changed = pass.run(&mut expr);
        assert!(changed);

        let mut heap2 = VecHeap::new();
        let val_after = tidepool_eval::eval(&expr, &Env::new(), &mut heap2).unwrap();

        match (val_before, val_after) {
            (Value::Lit(l1), Value::Lit(l2)) => {
                assert_eq!(l1, l2);
                assert_eq!(l1, Literal::LitInt(3));
            }
            (v1, v2) => panic!("Value mismatch or not Lit: {:?}, {:?}", v1, v2),
        }
    }

    #[test]
    fn test_case_known_lit() {
        // case 3 of w { LitAlt(1) -> 10; LitAlt(3) -> 30; Default -> 99 }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(3)),  // 0
            CoreFrame::Lit(Literal::LitInt(10)), // 1
            CoreFrame::Lit(Literal::LitInt(30)), // 2
            CoreFrame::Lit(Literal::LitInt(99)), // 3
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(10),
                alts: vec![
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(1)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(3)),
                        binders: vec![],
                        body: 2,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 3,
                    },
                ],
            }, // 4
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        // Result should be 30
        assert!(matches!(
            expr.nodes[expr.nodes.len() - 1],
            CoreFrame::Lit(Literal::LitInt(30))
        ));
    }

    #[test]
    fn test_case_known_lit_default() {
        // case 3 of w { LitAlt(1) -> 10; Default -> 99 }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(3)),  // 0
            CoreFrame::Lit(Literal::LitInt(10)), // 1
            CoreFrame::Lit(Literal::LitInt(99)), // 2
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(10),
                alts: vec![
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(1)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 2,
                    },
                ],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        // Result should be 99
        assert!(matches!(
            expr.nodes[expr.nodes.len() - 1],
            CoreFrame::Lit(Literal::LitInt(99))
        ));
    }

    #[test]
    fn test_case_unknown_untouched() {
        // case Var(x) of w { Default -> 42 }
        let nodes = vec![
            CoreFrame::Var(VarId(1)),            // 0: x
            CoreFrame::Lit(Literal::LitInt(42)), // 1
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::Default,
                    binders: vec![],
                    body: 1,
                }],
            }, // 2
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(!changed);
        // The expression must be left as it was.
        assert_eq!(expr.nodes.len(), 3);
        assert!(matches!(expr.nodes[2], CoreFrame::Case { .. }));
    }

    #[test]
    fn test_case_known_con_no_matching_alt_untouched() {
        // case Con(tag=2, []) of w { DataAlt(1) [] -> 1 }  -- no Default
        let nodes = vec![
            CoreFrame::Con {
                tag: DataConId(2),
                fields: vec![],
            }, // 0
            CoreFrame::Lit(Literal::LitInt(1)), // 1
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(5),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![],
                    body: 1,
                }],
            }, // 2
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(!changed);
        assert_eq!(expr.nodes.len(), 3);
    }

    #[test]
    fn test_case_known_con_arity_mismatch_untouched() {
        // case Con(tag=1, [42]) of w { DataAlt(1) [] -> 0 }  -- malformed: 0 binders, 1 field
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            }, // 1
            CoreFrame::Lit(Literal::LitInt(0)), // 2
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(6),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![],
                    body: 2,
                }],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(!changed);
        assert_eq!(expr.nodes.len(), 4);
    }

    #[test]
    fn test_case_nested_in_child_reduced() {
        // PrimOp(IntAdd, [case 5 of w { Default -> 7 }, 1])
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(5)), // 0
            CoreFrame::Lit(Literal::LitInt(7)), // 1
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(20),
                alts: vec![Alt {
                    con: AltCon::Default,
                    binders: vec![],
                    body: 1,
                }],
            }, // 2
            CoreFrame::Lit(Literal::LitInt(1)), // 3
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![2, 3],
            }, // 4
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;

        let mut heap = VecHeap::new();
        let val_before = tidepool_eval::eval(&expr, &Env::new(), &mut heap).unwrap();

        // The root is not a `case`, so the reduction must be found in a child.
        let changed = pass.run(&mut expr);
        assert!(changed);

        let mut heap2 = VecHeap::new();
        let val_after = tidepool_eval::eval(&expr, &Env::new(), &mut heap2).unwrap();

        match (val_before, val_after) {
            (Value::Lit(l1), Value::Lit(l2)) => {
                assert_eq!(l1, l2);
                assert_eq!(l1, Literal::LitInt(8));
            }
            (v1, v2) => panic!("Value mismatch or not Lit: {:?}, {:?}", v1, v2),
        }
    }

    #[test]
    fn test_case_binder_substituted() {
        // case Con(tag=1, [42]) of w { DataAlt(1) [y] -> w }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            }, // 1
            CoreFrame::Var(VarId(2)), // 2: w
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(2), // w
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(3)],
                    body: 2,
                }],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        // Result should be Con(tag=1, [42])
        if let CoreFrame::Con { tag, fields } = &expr.nodes[expr.nodes.len() - 1] {
            assert_eq!(tag.0, 1);
            assert_eq!(fields.len(), 1);
            assert!(
                matches!(&expr.nodes[fields[0]], CoreFrame::Lit(Literal::LitInt(42))),
                "Expected field to be 42"
            );
        } else {
            panic!("Expected Con, got {:?}", expr.nodes[expr.nodes.len() - 1]);
        }
    }

    #[test]
    fn test_case_reduce_preserves_eval() {
        // case Con(tag=1, [1, 2]) of w { DataAlt(1) [a, b] -> a + b; Default -> 0 }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Lit(Literal::LitInt(2)), // 1
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0, 1],
            }, // 2
            CoreFrame::Var(VarId(10)), // 3: a
            CoreFrame::Var(VarId(11)), // 4: b
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![3, 4],
            }, // 5
            CoreFrame::Lit(Literal::LitInt(0)), // 6
            CoreFrame::Case {
                scrutinee: 2,
                binder: VarId(12),
                alts: vec![
                    Alt {
                        con: AltCon::DataAlt(DataConId(1)),
                        binders: vec![VarId(10), VarId(11)],
                        body: 5,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 6,
                    },
                ],
            }, // 7
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;

        let mut heap = VecHeap::new();
        let val_before = tidepool_eval::eval(&expr, &Env::new(), &mut heap).unwrap();

        // Without this assertion the test would pass trivially if the pass did nothing.
        let changed = pass.run(&mut expr);
        assert!(changed);

        let mut heap2 = VecHeap::new();
        let val_after = tidepool_eval::eval(&expr, &Env::new(), &mut heap2).unwrap();

        match (val_before, val_after) {
            (Value::Lit(l1), Value::Lit(l2)) => assert_eq!(l1, l2),
            (Value::Con(t1, f1), Value::Con(t2, f2)) => {
                assert_eq!(t1, t2);
                assert_eq!(f1.len(), f2.len());
                // Simple check for literals in fields
                for (v1, v2) in f1.iter().zip(f2.iter()) {
                    if let (Value::Lit(ll1), Value::Lit(ll2)) = (v1, v2) {
                        assert_eq!(ll1, ll2);
                    }
                }
            }
            (v1, v2) => panic!(
                "Value mismatch or unsupported for eval check: {:?}, {:?}",
                v1, v2
            ),
        }
    }
}