Skip to main content

oxilean_codegen/opt_copy_prop/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::*;
6
7use super::types::{
8    ConstantFoldReport, CopyProp, CopyPropConfig, DeadBindingElim, DeadBindingReport, InlineReport,
9    InliningPass, InterferenceGraph, OptPipeline, PassKind, RegisterCoalescingHint, UsedVars,
10};
11
12/// Collects register coalescing hints from a copy propagation pass.
13///
14/// For each copy `let x = y`, if `x` and `y` do not interfere in the
15/// interference graph, emit a hint to coalesce them.
16#[allow(dead_code)]
17pub fn collect_coalescing_hints(
18    copies: &[(LcnfVarId, LcnfVarId)],
19    ig: &InterferenceGraph,
20) -> Vec<RegisterCoalescingHint> {
21    let mut hints = Vec::new();
22    for &(src, dst) in copies {
23        let is_safe = !ig.interfere(src, dst);
24        let benefit = if is_safe { 10 } else { 1 };
25        hints.push(RegisterCoalescingHint::new(src, dst, is_safe, benefit));
26    }
27    hints.sort_by(|a, b| b.benefit.cmp(&a.benefit));
28    hints
29}
/// Unit tests for the copy-propagation pass (`CopyProp`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lcnf::{
        LcnfAlt, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfParam, LcnfType, LcnfVarId,
    };
    /// Wraps `body` in a parameterless, non-recursive `test_fn` declaration
    /// returning `Nat` with `inline_cost` 1.
    pub(super) fn make_decl(body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: "test_fn".to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }
    /// `let x = fvar(y); return x`  →  `return y`
    #[test]
    pub(super) fn test_simple_fvar_copy() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(1),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(0)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(1)))),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::new(CopyPropConfig::default());
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0))));
        assert_eq!(pass.report().copies_eliminated, 1);
    }
    /// `let x = 42; return x`  →  `return 42`  (fold_literals=true)
    #[test]
    pub(super) fn test_literal_fold_enabled() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(42)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::new(CopyPropConfig {
            fold_literals: true,
            ..Default::default()
        });
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(42))));
        assert_eq!(pass.report().copies_eliminated, 1);
    }
    /// `let x = 7; return x`  is left untouched when `fold_literals=false`.
    #[test]
    pub(super) fn test_literal_fold_disabled() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(7)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::new(CopyPropConfig {
            fold_literals: false,
            ..Default::default()
        });
        pass.run(&mut decl);
        assert!(matches!(decl.body, LcnfExpr::Let { .. }));
        assert_eq!(pass.report().copies_eliminated, 0);
    }
    /// Transitive chain: `b = fvar(0); a = fvar(b); return a`  →  `return 0`
    /// (two copies eliminated, one chain followed).
    #[test]
    pub(super) fn test_transitive_chain() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(1),
            name: "b".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(0)),
            body: Box::new(LcnfExpr::Let {
                id: LcnfVarId(2),
                name: "a".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::FVar(LcnfVarId(1)),
                body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(2)))),
            }),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::new(CopyPropConfig::default());
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0))));
        assert_eq!(pass.report().copies_eliminated, 2);
        assert_eq!(pass.report().chains_followed, 1);
    }
    /// Chain depth limit: runs the same two-hop chain as `test_transitive_chain`
    /// with `max_chain_depth = 1`.
    ///
    /// NOTE(review): the assertion expects the fully resolved `return 0`, the
    /// same result as the unlimited case. Presumably each binding is rewritten
    /// to its already-resolved source as it is recorded, so one hop suffices
    /// here — confirm against `CopyProp`. As written, this test does not show
    /// that a second hop is *not* followed.
    #[test]
    pub(super) fn test_chain_depth_limit() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(1),
            name: "b".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(0)),
            body: Box::new(LcnfExpr::Let {
                id: LcnfVarId(2),
                name: "a".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::FVar(LcnfVarId(1)),
                body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(2)))),
            }),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::new(CopyPropConfig {
            max_chain_depth: 1,
            fold_literals: true,
        });
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0))));
    }
    /// App bindings are NOT propagated (conservative aliasing).
    #[test]
    pub(super) fn test_app_not_propagated() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(1),
            name: "r".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::App(LcnfArg::Var(LcnfVarId(0)), vec![]),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(1)))),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::default_pass();
        pass.run(&mut decl);
        assert!(matches!(decl.body, LcnfExpr::Let { .. }));
        assert_eq!(pass.report().copies_eliminated, 0);
    }
    /// A copy inside a `case` alternative (`q = fvar(p); return q`) is
    /// propagated within that alternative's body.
    /// NOTE: only one alternative plus a default is built, so per-branch
    /// independence is not exercised across multiple alternatives.
    #[test]
    pub(super) fn test_copy_in_case_branches() {
        let case_expr = LcnfExpr::Case {
            scrutinee: LcnfVarId(0),
            scrutinee_ty: LcnfType::Object,
            alts: vec![LcnfAlt {
                ctor_name: "A".to_string(),
                ctor_tag: 0,
                params: vec![LcnfParam {
                    id: LcnfVarId(1),
                    name: "p".to_string(),
                    ty: LcnfType::Nat,
                    erased: false,
                    borrowed: false,
                }],
                body: LcnfExpr::Let {
                    id: LcnfVarId(2),
                    name: "q".to_string(),
                    ty: LcnfType::Nat,
                    value: LcnfLetValue::FVar(LcnfVarId(1)),
                    body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(2)))),
                },
            }],
            default: Some(Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0))))),
        };
        let mut decl = make_decl(case_expr);
        let mut pass = CopyProp::default_pass();
        pass.run(&mut decl);
        match &decl.body {
            LcnfExpr::Case { alts, .. } => {
                assert_eq!(alts.len(), 1);
                assert_eq!(alts[0].body, LcnfExpr::Return(LcnfArg::Var(LcnfVarId(1))));
            }
            _ => panic!("Expected Case"),
        }
        assert!(pass.report().copies_eliminated >= 1);
    }
    /// Erased bindings are treated as copies and propagated:
    /// `let e = erased; return e`  →  `return erased`.
    #[test]
    pub(super) fn test_erased_copy_propagated() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "e".to_string(),
            ty: LcnfType::Erased,
            value: LcnfLetValue::Erased,
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::default_pass();
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Erased));
        assert_eq!(pass.report().copies_eliminated, 1);
    }
}
220#[allow(dead_code)]
221pub(super) fn collect_used(expr: &LcnfExpr, used: &mut UsedVars) {
222    match expr {
223        LcnfExpr::Return(arg) => collect_used_arg(arg, used),
224        LcnfExpr::Let { value, body, .. } => {
225            collect_used_value(value, used);
226            collect_used(body, used);
227        }
228        LcnfExpr::Case {
229            scrutinee,
230            scrutinee_ty: _,
231            alts,
232            default,
233            ..
234        } => {
235            used.vars.insert(*scrutinee);
236            for alt in alts {
237                collect_used(&alt.body, used);
238            }
239            if let Some(d) = default {
240                collect_used(d, used);
241            }
242        }
243        LcnfExpr::TailCall(callee, args) => {
244            collect_used_arg(callee, used);
245            for a in args {
246                collect_used_arg(a, used);
247            }
248        }
249        LcnfExpr::Unreachable => {}
250    }
251}
252#[allow(dead_code)]
253pub(super) fn collect_used_arg(arg: &LcnfArg, used: &mut UsedVars) {
254    if let LcnfArg::Var(id) = arg {
255        used.vars.insert(*id);
256    }
257}
258#[allow(dead_code)]
259pub(super) fn collect_used_value(val: &LcnfLetValue, used: &mut UsedVars) {
260    match val {
261        LcnfLetValue::FVar(id) => {
262            used.vars.insert(*id);
263        }
264        LcnfLetValue::App(callee, args) => {
265            collect_used_arg(callee, used);
266            for a in args {
267                collect_used_arg(a, used);
268            }
269        }
270        LcnfLetValue::Ctor(_, _, args) => {
271            for a in args {
272                collect_used_arg(a, used);
273            }
274        }
275        LcnfLetValue::Proj(_, _, id) => {
276            used.vars.insert(*id);
277        }
278        _ => {}
279    }
280}
/// Default cut-off on a function's `inline_cost` when deciding whether to inline it.
///
/// NOTE(review): presumably compared against `LcnfFunDecl::inline_cost` by
/// the inlining pass; whether the bound is inclusive is not visible here —
/// confirm in `InliningPass`.
#[allow(dead_code)]
pub const DEFAULT_INLINE_THRESHOLD: u32 = 5;
284/// Count the number of `let`-bindings in an expression tree.
285#[allow(dead_code)]
286pub fn count_let_bindings(expr: &LcnfExpr) -> usize {
287    match expr {
288        LcnfExpr::Let { body, .. } => 1 + count_let_bindings(body),
289        LcnfExpr::Case { alts, default, .. } => {
290            let alt_sum: usize = alts.iter().map(|a| count_let_bindings(&a.body)).sum();
291            let def_sum = default.as_ref().map_or(0, |d| count_let_bindings(d));
292            alt_sum + def_sum
293        }
294        _ => 0,
295    }
296}
297/// Return the maximum nesting depth of a LCNF expression.
298#[allow(dead_code)]
299pub fn expr_depth(expr: &LcnfExpr) -> usize {
300    match expr {
301        LcnfExpr::Let { body, .. } => 1 + expr_depth(body),
302        LcnfExpr::Case { alts, default, .. } => {
303            let alt_max = alts.iter().map(|a| expr_depth(&a.body)).max().unwrap_or(0);
304            let def_max = default.as_ref().map_or(0, |d| expr_depth(d));
305            1 + alt_max.max(def_max)
306        }
307        LcnfExpr::TailCall(_, args) => args.len(),
308        _ => 0,
309    }
310}
311/// Check whether an expression contains any tail calls to the given id.
312#[allow(dead_code)]
313pub fn has_tail_call_to(expr: &LcnfExpr, target: LcnfVarId) -> bool {
314    match expr {
315        LcnfExpr::TailCall(LcnfArg::Var(id), _) => *id == target,
316        LcnfExpr::Let { body, .. } => has_tail_call_to(body, target),
317        LcnfExpr::Case { alts, default, .. } => {
318            alts.iter().any(|a| has_tail_call_to(&a.body, target))
319                || default
320                    .as_ref()
321                    .is_some_and(|d| has_tail_call_to(d, target))
322        }
323        _ => false,
324    }
325}
326/// Collect all `LcnfVarId`s bound by `let` in an expression.
327#[allow(dead_code)]
328pub fn collect_bound_vars(expr: &LcnfExpr, out: &mut Vec<LcnfVarId>) {
329    match expr {
330        LcnfExpr::Let { id, body, .. } => {
331            out.push(*id);
332            collect_bound_vars(body, out);
333        }
334        LcnfExpr::Case { alts, default, .. } => {
335            for alt in alts {
336                collect_bound_vars(&alt.body, out);
337            }
338            if let Some(d) = default {
339                collect_bound_vars(d, out);
340            }
341        }
342        _ => {}
343    }
344}
/// Tests for the expression-tree helpers in this file and for the other
/// passes, pipeline, and report types imported from `super::types`.
#[cfg(test)]
mod tests_extended {
    use super::*;
    /// Build a variable id from a small integer.
    pub(super) fn make_var(n: u32) -> LcnfVarId {
        LcnfVarId(u64::from(n))
    }
    /// Wrap `body` in a parameterless, non-recursive `test_fn` declaration
    /// returning `Nat` with `inline_cost` 1.
    pub(super) fn make_simple_decl(body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: "test_fn".to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }
    /// An unused `let x = 42` is removed, leaving only the `return`.
    #[test]
    pub(super) fn test_dead_binding_removal() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(99),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(42)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))),
        };
        let mut decl = make_simple_decl(body);
        let mut pass = DeadBindingElim::default_pass();
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))));
        // The binding was dropped (checked above), so the report must record
        // at least one removal; `>= 0` would be vacuous for an unsigned count.
        assert!(pass.report().bindings_removed >= 1);
    }
    #[test]
    pub(super) fn test_count_let_bindings() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "a".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(1)),
            body: Box::new(LcnfExpr::Let {
                id: LcnfVarId(1),
                name: "b".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::Lit(LcnfLit::Nat(2)),
                body: Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))),
            }),
        };
        assert_eq!(count_let_bindings(&body), 2);
    }
    #[test]
    pub(super) fn test_expr_depth() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "a".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(0)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))),
        };
        assert_eq!(expr_depth(&body), 1);
    }
    #[test]
    pub(super) fn test_has_tail_call_to() {
        let target = make_var(7);
        let body = LcnfExpr::TailCall(LcnfArg::Var(target), vec![]);
        assert!(has_tail_call_to(&body, target));
        assert!(!has_tail_call_to(&body, make_var(8)));
    }
    #[test]
    pub(super) fn test_collect_bound_vars() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(5),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(0)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(5)))),
        };
        let mut bound = vec![];
        collect_bound_vars(&body, &mut bound);
        assert_eq!(bound, vec![LcnfVarId(5)]);
    }
    /// The default pipeline propagates `let x = n; return x` for a parameter `n`.
    #[test]
    pub(super) fn test_opt_pipeline_default() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(1)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut decl = make_simple_decl(body);
        decl.params.push(LcnfParam {
            id: LcnfVarId(1),
            name: "n".to_string(),
            ty: LcnfType::Nat,
            erased: false,
            borrowed: false,
        });
        let mut pipeline = OptPipeline::new();
        let result = pipeline.run(&mut decl);
        assert!(result.copy_prop.copies_eliminated >= 1);
    }
    #[test]
    pub(super) fn test_pass_kind_display() {
        assert_eq!(PassKind::CopyProp.to_string(), "CopyProp");
        assert_eq!(PassKind::DeadBinding.to_string(), "DeadBinding");
        assert_eq!(PassKind::ConstantFold.to_string(), "ConstantFold");
        assert_eq!(PassKind::Inlining.to_string(), "Inlining");
    }
    #[test]
    pub(super) fn test_inline_candidate() {
        let pass = InliningPass::default_pass();
        let cheap = LcnfFunDecl {
            name: "cheap".to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        };
        let expensive = LcnfFunDecl {
            inline_cost: 100,
            name: "expensive".to_string(),
            ..cheap.clone()
        };
        assert!(pass.is_inline_candidate(&cheap));
        assert!(!pass.is_inline_candidate(&expensive));
    }
    /// `CopyPropConfig`'s `Display` output names the type.
    #[test]
    pub(super) fn test_copy_prop_config_display() {
        let cfg = CopyPropConfig::default();
        let s = format!("{}", cfg);
        assert!(s.contains("CopyPropConfig"));
    }
    #[test]
    pub(super) fn test_dead_binding_report_display() {
        let r = DeadBindingReport {
            bindings_removed: 3,
            passes_run: 2,
        };
        let s = format!("{}", r);
        assert!(s.contains("removed=3"));
        assert!(s.contains("passes=2"));
    }
    #[test]
    pub(super) fn test_constant_fold_report_display() {
        let r = ConstantFoldReport { folds_performed: 7 };
        let s = format!("{}", r);
        assert!(s.contains("folds=7"));
    }
    #[test]
    pub(super) fn test_inline_report_display() {
        let r = InlineReport {
            inlines_performed: 2,
            functions_considered: 10,
        };
        let s = format!("{}", r);
        assert!(s.contains("inlined=2"));
        assert!(s.contains("considered=10"));
    }
}