// tidepool_optimize/inline.rs

1use crate::occ::{get_occ, occ_analysis, Occ};
2use tidepool_eval::{Changed, Pass};
3use tidepool_repr::{get_children, replace_subtree, CoreExpr, CoreFrame};
4
5/// Inlining pass: eliminates single-use `LetNonRec` bindings by substituting the RHS directly at the use site.
6pub struct Inline;
7
8impl Pass for Inline {
9    fn run(&self, expr: &mut CoreExpr) -> Changed {
10        if expr.nodes.is_empty() {
11            return false;
12        }
13        let occ_map = occ_analysis(expr);
14        match try_inline(expr, &occ_map) {
15            Some(new_expr) => {
16                *expr = new_expr;
17                true
18            }
19            None => false,
20        }
21    }
22
23    fn name(&self) -> &str {
24        "Inline"
25    }
26}
27
28fn try_inline(expr: &CoreExpr, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
29    try_inline_at(expr, expr.nodes.len() - 1, occ_map)
30}
31
32fn try_inline_at(expr: &CoreExpr, idx: usize, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
33    match &expr.nodes[idx] {
34        CoreFrame::LetNonRec { binder, rhs, body } => {
35            if get_occ(occ_map, *binder) == Occ::Once {
36                // Inline: substitute binder -> rhs in body
37                let body_tree = expr.extract_subtree(*body);
38                let rhs_tree = expr.extract_subtree(*rhs);
39                let inlined = tidepool_repr::subst::subst(&body_tree, *binder, &rhs_tree);
40                Some(replace_subtree(expr, idx, &inlined))
41            } else {
42                // Try children
43                try_inline_at(expr, *rhs, occ_map).or_else(|| try_inline_at(expr, *body, occ_map))
44            }
45        }
46        // Never inline LetRec, even if Once (it might be recursive via own RHS)
47        _ => try_children(expr, idx, occ_map),
48    }
49}
50
51fn try_children(expr: &CoreExpr, idx: usize, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
52    let children = get_children(&expr.nodes[idx]);
53    for child in children {
54        if let Some(result) = try_inline_at(expr, child, occ_map) {
55            return Some(result);
56        }
57    }
58    None
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::{eval, Env, Value, VecHeap};
    use tidepool_repr::{Literal, PrimOpKind, VarId};

    /// Wraps a flat node list into a `CoreExpr`.
    fn tree(nodes: Vec<CoreFrame<usize>>) -> CoreExpr {
        CoreExpr { nodes }
    }

    /// Shorthand for a `LetNonRec` frame.
    fn let_nonrec(binder: VarId, rhs: usize, body: usize) -> CoreFrame<usize> {
        CoreFrame::LetNonRec { binder, rhs, body }
    }

    /// Builds `let x = <rhs> in x`; `rhs` must be a single leaf frame.
    fn let_single_use(x: VarId, rhs: CoreFrame<usize>) -> CoreExpr {
        tree(vec![
            rhs,                 // 0
            CoreFrame::Var(x),   // 1
            let_nonrec(x, 0, 1), // 2
        ])
    }

    /// Builds `let x = <rhs> in x + x`; `rhs` must be a single leaf frame.
    fn let_double_use(x: VarId, rhs: CoreFrame<usize>) -> CoreExpr {
        tree(vec![
            rhs,               // 0
            CoreFrame::Var(x), // 1
            CoreFrame::Var(x), // 2
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            }, // 3
            let_nonrec(x, 0, 3), // 4
        ])
    }

    // 1. `let x = 42 in x` becomes `42`: the binder occurs once, so it is inlined.
    #[test]
    fn test_inline_single_use() {
        let mut expr = let_single_use(VarId(1), CoreFrame::Lit(Literal::LitInt(42)));

        assert!(Inline.run(&mut expr));
        assert_eq!(expr.nodes, vec![CoreFrame::Lit(Literal::LitInt(42))]);
    }

    // 2. `let x = 42 in x + x` is left alone: the binder occurs more than once.
    #[test]
    fn test_inline_multi_use_preserved() {
        let mut expr = let_double_use(VarId(1), CoreFrame::Lit(Literal::LitInt(42)));

        assert!(!Inline.run(&mut expr));
    }

    // 3. `let x = 42 in 0` is left alone by this pass; removing dead bindings is DCE's job.
    #[test]
    fn test_inline_dead_preserved() {
        let x = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Lit(Literal::LitInt(0)),  // 1
            let_nonrec(x, 0, 1),                 // 2
        ]);

        assert!(!Inline.run(&mut expr));
    }

    // 4. `let x = 1 in let y = x in y` collapses to `1`, one binding per run.
    #[test]
    fn test_inline_nested() {
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Var(x),                  // 1
            CoreFrame::Var(y),                  // 2
            let_nonrec(y, 1, 2),                // 3
            let_nonrec(x, 0, 3),                // 4
        ]);

        // First run inlines the outer binding `x = 1`, leaving `let y = 1 in y`.
        assert!(Inline.run(&mut expr));
        // Second run inlines the remaining binding `y = 1`, leaving `1`.
        assert!(Inline.run(&mut expr));

        assert_eq!(expr.nodes, vec![CoreFrame::Lit(Literal::LitInt(1))]);
    }

    // 5. `letrec f = f in f` is left alone: `LetRec` binders are never inlined.
    #[test]
    fn test_inline_letrec_not_inlined() {
        let f = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Var(f), // 0
            CoreFrame::Var(f), // 1
            CoreFrame::LetRec {
                bindings: vec![(f, 0)],
                body: 1,
            }, // 2
        ]);

        assert!(!Inline.run(&mut expr));
    }

    // 6. `let x = y in \y. x` becomes `\y'. y`: the lambda binder is renamed so the free `y`
    //    coming from the RHS is not captured.
    #[test]
    fn test_inline_capture_avoiding() {
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = tree(vec![
            CoreFrame::Var(y),                     // 0: rhs
            CoreFrame::Var(x),                     // 1
            CoreFrame::Lam { binder: y, body: 1 }, // 2: body
            let_nonrec(x, 0, 2),                   // 3
        ]);

        assert!(Inline.run(&mut expr));

        // The root must be a lambda with a fresh binder whose body is the original `y`.
        let root = expr.nodes.len() - 1;
        match &expr.nodes[root] {
            CoreFrame::Lam { binder, body } => {
                assert_ne!(*binder, y);
                match &expr.nodes[*body] {
                    CoreFrame::Var(v) => assert_eq!(*v, y),
                    _ => panic!("Body should be Var(y)"),
                }
            }
            _ => panic!("Result should be Lam"),
        }
    }

    // 7. A single-use let evaluates to the same literal before and after the pass, and a
    //    multi-use let comes out of the pass structurally unchanged.
    #[test]
    fn test_inline_preserves_eval() {
        let x = VarId(1);

        // Case A: single use, expected to be inlined.
        let expr_once = let_single_use(x, CoreFrame::Lit(Literal::LitInt(21)));
        let mut expr_once_reduced = expr_once.clone();
        Inline.run(&mut expr_once_reduced);

        let mut heap = VecHeap::new();
        let env = Env::new();
        let before = eval(&expr_once, &env, &mut heap).unwrap();
        let after = eval(&expr_once_reduced, &env, &mut heap).unwrap();
        if let (Value::Lit(l1), Value::Lit(l2)) = (before, after) {
            assert_eq!(l1, l2);
        } else {
            panic!("Expected literals");
        }

        // Case B: multiple uses, expected to be left untouched.
        let expr_many_orig = let_double_use(x, CoreFrame::Lit(Literal::LitInt(21)));
        let mut expr_many = expr_many_orig.clone();
        Inline.run(&mut expr_many);
        assert_eq!(expr_many, expr_many_orig);
    }
}