Skip to main content

oxilean_codegen/opt_passes/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::{LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfVarId};
6use std::collections::{HashMap, HashSet};
7
8use super::types::{
9    BetaReductionPass, ConstantFoldingPass, CopyPropagationPass, DeadCodeEliminationPass,
10    ExprSizeEstimator, IdentityEliminationPass, InlineCostEstimator, OPAnalysisCache,
11    OPConstantFoldingHelper, OPDepGraph, OPDominatorTree, OPLivenessInfo, OPPassConfig,
12    OPPassPhase, OPPassRegistry, OPPassStats, OPWorklist, PassDependency, PassManager, PassStats,
13    PgoHints, StrengthReductionPass, UnreachableCodeEliminationPass,
14};
15use std::fmt;
16
17/// Trait for optimization passes that operate on LCNF function declarations.
18pub trait OptPass: fmt::Debug {
19    /// Human-readable name of this pass.
20    fn name(&self) -> &str;
21    /// Run the pass on a set of declarations, returning the number of changes made.
22    fn run_pass(&mut self, decls: &mut [LcnfFunDecl]) -> usize;
23    /// Whether this pass is enabled.
24    fn is_enabled(&self) -> bool {
25        true
26    }
27    /// Dependencies: names of passes that must run before this one.
28    fn dependencies(&self) -> Vec<&str> {
29        Vec::new()
30    }
31}
32/// Replace all occurrences of variable `from` with `to` in `expr`.
33pub fn substitute_var_in_expr(expr: &mut LcnfExpr, from: LcnfVarId, to: LcnfVarId) {
34    let subst_arg = |a: &mut LcnfArg| {
35        if let LcnfArg::Var(v) = a {
36            if *v == from {
37                *v = to;
38            }
39        }
40    };
41    let subst_value = |val: &mut LcnfLetValue| match val {
42        LcnfLetValue::App(f, args) => {
43            subst_arg(f);
44            for a in args {
45                subst_arg(a);
46            }
47        }
48        LcnfLetValue::FVar(v) => {
49            if *v == from {
50                *v = to;
51            }
52        }
53        LcnfLetValue::Ctor(_, _, args) | LcnfLetValue::Reuse(_, _, _, args) => {
54            for a in args {
55                subst_arg(a);
56            }
57        }
58        LcnfLetValue::Proj(_, _, v) => {
59            if *v == from {
60                *v = to;
61            }
62        }
63        LcnfLetValue::Reset(v) => {
64            if *v == from {
65                *v = to;
66            }
67        }
68        LcnfLetValue::Lit(_) | LcnfLetValue::Erased => {}
69    };
70    match expr {
71        LcnfExpr::Let { value, body, .. } => {
72            subst_value(value);
73            substitute_var_in_expr(body, from, to);
74        }
75        LcnfExpr::Case {
76            scrutinee,
77            alts,
78            default,
79            ..
80        } => {
81            if *scrutinee == from {
82                *scrutinee = to;
83            }
84            for alt in alts.iter_mut() {
85                substitute_var_in_expr(&mut alt.body, from, to);
86            }
87            if let Some(def) = default {
88                substitute_var_in_expr(def, from, to);
89            }
90        }
91        LcnfExpr::Return(a) => subst_arg(a),
92        LcnfExpr::TailCall(f, args) => {
93            subst_arg(f);
94            for a in args {
95                subst_arg(a);
96            }
97        }
98        LcnfExpr::Unreachable => {}
99    }
100}
101/// Run all optimization passes in sequence.
102pub fn run_all_passes(_decls: &mut Vec<LcnfFunDecl>, pgo: Option<&PgoHints>) {
103    let mut _dce = DeadCodeEliminationPass::new();
104    let mut _cp = CopyPropagationPass::new();
105    let mut _cf = ConstantFoldingPass::new();
106    let mut _beta = BetaReductionPass::new();
107    let mut _identity = IdentityEliminationPass::new();
108    let mut _unreachable = UnreachableCodeEliminationPass::new();
109    let _ = pgo;
110}
/// Unit tests for the optimization passes and helpers used by this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lcnf::{LcnfLit, LcnfType};

    /// Shorthand for building an [`LcnfVarId`].
    pub(super) fn vid(n: u64) -> LcnfVarId {
        LcnfVarId(n)
    }

    /// Builds a non-recursive, parameterless `Nat`-returning declaration.
    pub(super) fn mk_fun_decl(name: &str, body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 0,
        }
    }

    /// Builds `let x{id} : Nat := value; body`.
    pub(super) fn mk_let(id: u64, value: LcnfLetValue, body: LcnfExpr) -> LcnfExpr {
        LcnfExpr::Let {
            id: vid(id),
            name: format!("x{}", id),
            ty: LcnfType::Nat,
            value,
            body: Box::new(body),
        }
    }

    #[test]
    pub(super) fn test_constant_folding_pass() {
        let mut pass = ConstantFoldingPass::new();
        assert_eq!(pass.folds_performed, 0);
        assert_eq!(pass.try_fold_nat_op("add", 3, 4), Some(7));
        assert_eq!(pass.try_fold_nat_op("sub", 5, 3), Some(2));
        assert_eq!(pass.try_fold_nat_op("mul", 2, 6), Some(12));
        assert_eq!(pass.try_fold_nat_op("div", 10, 2), Some(5));
        assert_eq!(pass.try_fold_nat_op("div", 10, 0), None);
        assert_eq!(pass.try_fold_nat_op("mod", 10, 3), Some(1));
        assert_eq!(pass.try_fold_nat_op("mod", 10, 0), None);
        assert_eq!(pass.try_fold_nat_op("min", 3, 7), Some(3));
        assert_eq!(pass.try_fold_nat_op("max", 3, 7), Some(7));
        assert_eq!(pass.try_fold_nat_op("pow", 2, 10), Some(1024));
        assert_eq!(pass.try_fold_nat_op("and", 0xFF, 0x0F), Some(0x0F));
        assert_eq!(pass.try_fold_nat_op("or", 0xF0, 0x0F), Some(0xFF));
        assert_eq!(pass.try_fold_nat_op("xor", 0xFF, 0xFF), Some(0));
        assert_eq!(pass.try_fold_nat_op("shl", 1, 3), Some(8));
        assert_eq!(pass.try_fold_nat_op("shr", 16, 2), Some(4));
        assert_eq!(pass.try_fold_nat_op("unknown", 1, 2), None);
    }
    #[test]
    pub(super) fn test_constant_folding_bool_ops() {
        let pass = ConstantFoldingPass::new();
        assert_eq!(pass.try_fold_bool_op("and", true, false), Some(false));
        assert_eq!(pass.try_fold_bool_op("or", true, false), Some(true));
        assert_eq!(pass.try_fold_bool_op("xor", true, true), Some(false));
        assert_eq!(pass.try_fold_bool_op("eq", true, true), Some(true));
        assert_eq!(pass.try_fold_bool_op("ne", true, false), Some(true));
        assert_eq!(pass.try_fold_bool_op("bad", true, false), None);
    }
    #[test]
    pub(super) fn test_constant_folding_cmp_ops() {
        let pass = ConstantFoldingPass::new();
        assert_eq!(pass.try_fold_cmp("eq", 5, 5), Some(true));
        assert_eq!(pass.try_fold_cmp("ne", 5, 5), Some(false));
        assert_eq!(pass.try_fold_cmp("lt", 3, 5), Some(true));
        assert_eq!(pass.try_fold_cmp("le", 5, 5), Some(true));
        assert_eq!(pass.try_fold_cmp("gt", 5, 3), Some(true));
        assert_eq!(pass.try_fold_cmp("ge", 3, 5), Some(false));
        assert_eq!(pass.try_fold_cmp("bad", 1, 2), None);
    }
    #[test]
    pub(super) fn test_constant_folding_run() {
        let mut pass = ConstantFoldingPass::new();
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(42)));
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert_eq!(pass.folds_performed, 0);
    }
    #[test]
    pub(super) fn test_constant_folding_debug() {
        let pass = ConstantFoldingPass::new();
        let s = format!("{:?}", pass);
        assert!(s.contains("ConstantFoldingPass"));
    }
    #[test]
    pub(super) fn test_dead_code_elimination_pass() {
        let mut pass = DeadCodeEliminationPass::new();
        assert_eq!(pass.removed, 0);
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            mk_let(
                1,
                LcnfLetValue::Lit(LcnfLit::Nat(99)),
                LcnfExpr::Return(LcnfArg::Var(vid(1))),
            ),
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert!(pass.removed > 0, "expected dead let to be removed");
    }
    #[test]
    pub(super) fn test_dead_code_elimination_debug() {
        let pass = DeadCodeEliminationPass::new();
        let s = format!("{:?}", pass);
        assert!(s.contains("DeadCodeEliminationPass"));
    }
    #[test]
    pub(super) fn test_copy_propagation_pass() {
        let mut pass = CopyPropagationPass::new();
        assert_eq!(pass.substitutions, 0);
        let body = mk_let(
            1,
            LcnfLetValue::FVar(vid(0)),
            LcnfExpr::Return(LcnfArg::Var(vid(1))),
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert!(pass.substitutions > 0, "expected copy to be propagated");
    }
    #[test]
    pub(super) fn test_copy_propagation_debug() {
        let pass = CopyPropagationPass::new();
        let s = format!("{:?}", pass);
        assert!(s.contains("CopyPropagationPass"));
    }
    #[test]
    pub(super) fn test_beta_reduction_pass() {
        let mut pass = BetaReductionPass::new();
        assert_eq!(pass.reductions, 0);
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert_eq!(pass.reductions, 0);
        let body2 = LcnfExpr::TailCall(LcnfArg::Lit(LcnfLit::Nat(0)), vec![]);
        let mut decls2 = vec![mk_fun_decl("g", body2)];
        pass.run(&mut decls2);
        assert_eq!(pass.reductions, 1);
    }
    #[test]
    pub(super) fn test_beta_reduction_debug() {
        let pass = BetaReductionPass::new();
        let s = format!("{:?}", pass);
        assert!(s.contains("BetaReductionPass"));
    }
    #[test]
    pub(super) fn test_identity_elimination() {
        let mut pass = IdentityEliminationPass::new();
        let body = mk_let(
            0,
            LcnfLetValue::FVar(vid(0)),
            LcnfExpr::Return(LcnfArg::Var(vid(0))),
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert_eq!(pass.eliminated, 1);
        assert!(matches!(decls[0].body, LcnfExpr::Return(_)));
    }
    #[test]
    pub(super) fn test_identity_elimination_no_self_ref() {
        let mut pass = IdentityEliminationPass::new();
        let body = mk_let(
            1,
            LcnfLetValue::FVar(vid(0)),
            LcnfExpr::Return(LcnfArg::Var(vid(1))),
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert_eq!(pass.eliminated, 0);
    }
    #[test]
    pub(super) fn test_strength_reduction_power_of_two() {
        assert!(StrengthReductionPass::is_power_of_two(1));
        assert!(StrengthReductionPass::is_power_of_two(2));
        assert!(StrengthReductionPass::is_power_of_two(4));
        assert!(StrengthReductionPass::is_power_of_two(1024));
        assert!(!StrengthReductionPass::is_power_of_two(0));
        assert!(!StrengthReductionPass::is_power_of_two(3));
        assert!(!StrengthReductionPass::is_power_of_two(6));
    }
    #[test]
    pub(super) fn test_strength_reduction_log2() {
        assert_eq!(StrengthReductionPass::log2_exact(1), Some(0));
        assert_eq!(StrengthReductionPass::log2_exact(2), Some(1));
        assert_eq!(StrengthReductionPass::log2_exact(8), Some(3));
        assert_eq!(StrengthReductionPass::log2_exact(1024), Some(10));
        assert_eq!(StrengthReductionPass::log2_exact(0), None);
        assert_eq!(StrengthReductionPass::log2_exact(3), None);
    }
    #[test]
    pub(super) fn test_strength_reduction_is_mask() {
        assert!(StrengthReductionPass::is_mask(1));
        assert!(StrengthReductionPass::is_mask(3));
        assert!(StrengthReductionPass::is_mask(7));
        assert!(StrengthReductionPass::is_mask(0xFF));
        assert!(!StrengthReductionPass::is_mask(0));
        assert!(!StrengthReductionPass::is_mask(5));
    }
    #[test]
    pub(super) fn test_strength_reduction_bit_ops() {
        assert_eq!(StrengthReductionPass::ctz(8), 3);
        assert_eq!(StrengthReductionPass::ctz(0), 64);
        assert_eq!(StrengthReductionPass::clz(1), 63);
        assert_eq!(StrengthReductionPass::popcount(0xFF), 8);
        assert_eq!(StrengthReductionPass::popcount(0), 0);
    }
    #[test]
    pub(super) fn test_unreachable_code_elimination() {
        let mut pass = UnreachableCodeEliminationPass::new();
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            LcnfExpr::Unreachable,
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert_eq!(pass.eliminated, 1);
        assert!(matches!(decls[0].body, LcnfExpr::Unreachable));
    }
    #[test]
    pub(super) fn test_unreachable_nested() {
        let mut pass = UnreachableCodeEliminationPass::new();
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            mk_let(1, LcnfLetValue::Lit(LcnfLit::Nat(2)), LcnfExpr::Unreachable),
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert!(pass.eliminated >= 2);
    }
    #[test]
    pub(super) fn test_expr_size_count_lets() {
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            mk_let(
                1,
                LcnfLetValue::Lit(LcnfLit::Nat(2)),
                LcnfExpr::Return(LcnfArg::Var(vid(1))),
            ),
        );
        assert_eq!(ExprSizeEstimator::count_lets(&body), 2);
    }
    #[test]
    pub(super) fn test_expr_size_count_cases() {
        let body = LcnfExpr::Case {
            scrutinee: vid(0),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![],
            default: Some(Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))))),
        };
        assert_eq!(ExprSizeEstimator::count_cases(&body), 1);
    }
    #[test]
    pub(super) fn test_expr_size_complexity() {
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            LcnfExpr::Return(LcnfArg::Var(vid(0))),
        );
        assert_eq!(ExprSizeEstimator::complexity(&body), 1);
    }
    #[test]
    pub(super) fn test_expr_size_max_depth() {
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            mk_let(
                1,
                LcnfLetValue::Lit(LcnfLit::Nat(2)),
                LcnfExpr::Return(LcnfArg::Var(vid(1))),
            ),
        );
        assert_eq!(ExprSizeEstimator::max_depth(&body), 2);
    }
    #[test]
    pub(super) fn test_expr_size_is_trivial() {
        assert!(ExprSizeEstimator::is_trivial(&LcnfExpr::Return(
            LcnfArg::Lit(LcnfLit::Nat(0))
        )));
        assert!(ExprSizeEstimator::is_trivial(&LcnfExpr::Unreachable));
        assert!(!ExprSizeEstimator::is_trivial(&mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(0)),
            LcnfExpr::Unreachable
        )));
    }
    #[test]
    pub(super) fn test_expr_size_should_inline() {
        let small = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        assert!(ExprSizeEstimator::should_inline(&small, 5));
        let big = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            mk_let(
                1,
                LcnfLetValue::Lit(LcnfLit::Nat(2)),
                mk_let(
                    2,
                    LcnfLetValue::Lit(LcnfLit::Nat(3)),
                    mk_let(
                        3,
                        LcnfLetValue::Lit(LcnfLit::Nat(4)),
                        mk_let(
                            4,
                            LcnfLetValue::Lit(LcnfLit::Nat(5)),
                            mk_let(
                                5,
                                LcnfLetValue::Lit(LcnfLit::Nat(6)),
                                LcnfExpr::Return(LcnfArg::Var(vid(5))),
                            ),
                        ),
                    ),
                ),
            ),
        );
        assert!(!ExprSizeEstimator::should_inline(&big, 3));
    }
    #[test]
    pub(super) fn test_expr_size_var_refs() {
        let body = mk_let(
            1,
            LcnfLetValue::FVar(vid(0)),
            LcnfExpr::Return(LcnfArg::Var(vid(1))),
        );
        assert_eq!(ExprSizeEstimator::count_var_refs(&body), 2);
    }
    #[test]
    pub(super) fn test_pgo_hints() {
        let mut hints = PgoHints::new();
        assert!(!hints.is_hot("foo"));
        assert!(!hints.should_inline("foo"));
        hints.mark_hot("foo");
        hints.mark_hot("bar");
        hints.mark_hot("foo");
        assert!(hints.is_hot("foo"));
        assert!(hints.is_hot("bar"));
        assert_eq!(hints.hot_functions.len(), 2);
        hints.mark_inline("baz");
        assert!(hints.should_inline("baz"));
        assert!(!hints.should_inline("qux"));
    }
    #[test]
    pub(super) fn test_pgo_hints_cold() {
        let mut hints = PgoHints::new();
        hints.mark_cold("cold_fn");
        assert!(hints.is_cold("cold_fn"));
        assert!(!hints.is_cold("other"));
    }
    #[test]
    pub(super) fn test_pgo_hints_total() {
        let mut hints = PgoHints::new();
        hints.mark_hot("a");
        hints.mark_cold("b");
        hints.mark_inline("c");
        hints.record_call("d", 10);
        assert_eq!(hints.total_hints(), 4);
    }
    #[test]
    pub(super) fn test_pgo_hints_classify() {
        let mut hints = PgoHints::new();
        hints.mark_hot("h");
        hints.mark_cold("c");
        assert_eq!(hints.classify("h"), "hot");
        assert_eq!(hints.classify("c"), "cold");
        assert_eq!(hints.classify("other"), "normal");
    }
    #[test]
    pub(super) fn test_pgo_hints_merge() {
        let mut h1 = PgoHints::new();
        h1.mark_hot("a");
        h1.record_call("f", 5);
        let mut h2 = PgoHints::new();
        h2.mark_hot("b");
        h2.mark_cold("c");
        h2.record_call("f", 3);
        h1.merge(&h2);
        assert!(h1.is_hot("a"));
        assert!(h1.is_hot("b"));
        assert!(h1.is_cold("c"));
        assert_eq!(h1.call_count("f"), 8);
    }
    #[test]
    pub(super) fn test_pgo_hints_call_count() {
        let mut hints = PgoHints::new();
        hints.record_call("f", 10);
        hints.record_call("f", 5);
        assert_eq!(hints.call_count("f"), 15);
        assert_eq!(hints.call_count("g"), 0);
    }
    #[test]
    pub(super) fn test_inline_cost_estimator_trivial() {
        let est = InlineCostEstimator::default();
        let decl = mk_fun_decl("f", LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))));
        assert!(est.should_inline(&decl, None));
    }
    #[test]
    pub(super) fn test_inline_cost_estimator_with_pgo() {
        let est = InlineCostEstimator::default();
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            mk_let(
                1,
                LcnfLetValue::Lit(LcnfLit::Nat(2)),
                mk_let(
                    2,
                    LcnfLetValue::Lit(LcnfLit::Nat(3)),
                    mk_let(
                        3,
                        LcnfLetValue::Lit(LcnfLit::Nat(4)),
                        LcnfExpr::Return(LcnfArg::Var(vid(3))),
                    ),
                ),
            ),
        );
        let decl = mk_fun_decl("medium", body);
        let mut pgo = PgoHints::new();
        pgo.mark_inline("medium");
        assert!(est.should_inline(&decl, Some(&pgo)));
    }
    #[test]
    pub(super) fn test_inline_cost_recursive_penalty() {
        let est = InlineCostEstimator::default();
        let mut decl = mk_fun_decl("rec", LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))));
        decl.is_recursive = true;
        let cost = est.cost(&decl);
        assert_eq!(cost, 10);
    }
    #[test]
    pub(super) fn test_pass_manager_new() {
        let pm = PassManager::new();
        assert_eq!(pm.num_passes(), 0);
        assert_eq!(pm.max_iterations, 10);
    }
    #[test]
    pub(super) fn test_pass_manager_add_pass() {
        let mut pm = PassManager::new();
        pm.add_pass("dce");
        pm.add_pass("cp");
        pm.add_pass("dce");
        assert_eq!(pm.num_passes(), 2);
    }
    #[test]
    pub(super) fn test_pass_manager_record_run() {
        let mut pm = PassManager::new();
        pm.add_pass("dce");
        pm.record_run("dce", 5, 100);
        let stats = pm.get_stats("dce").expect("stats should exist");
        assert_eq!(stats.run_count, 1);
        assert_eq!(stats.total_changes, 5);
    }
    #[test]
    pub(super) fn test_pass_manager_topological_order() {
        let mut pm = PassManager::new();
        pm.add_pass("beta");
        pm.add_pass("dce");
        pm.add_pass("cp");
        pm.add_dependency("dce", "cp");
        pm.add_dependency("cp", "beta");
        let order = pm.topological_order().expect("no cycle");
        let beta_pos = order
            .iter()
            .position(|n| n == "beta")
            .expect("beta_pos position should exist");
        let cp_pos = order
            .iter()
            .position(|n| n == "cp")
            .expect("cp_pos position should exist");
        let dce_pos = order
            .iter()
            .position(|n| n == "dce")
            .expect("dce_pos position should exist");
        assert!(beta_pos < cp_pos);
        assert!(cp_pos < dce_pos);
    }
    #[test]
    pub(super) fn test_pass_manager_cycle_detection() {
        let mut pm = PassManager::new();
        pm.add_pass("a");
        pm.add_pass("b");
        pm.add_dependency("a", "b");
        pm.add_dependency("b", "a");
        assert!(pm.has_cycle());
        assert!(pm.topological_order().is_none());
    }
    #[test]
    pub(super) fn test_pass_manager_no_cycle() {
        let mut pm = PassManager::new();
        pm.add_pass("a");
        pm.add_pass("b");
        pm.add_dependency("b", "a");
        assert!(!pm.has_cycle());
    }
    #[test]
    pub(super) fn test_pass_manager_total_changes() {
        let mut pm = PassManager::new();
        pm.add_pass("a");
        pm.add_pass("b");
        pm.record_run("a", 3, 0);
        pm.record_run("b", 7, 0);
        assert_eq!(pm.total_changes(), 10);
        assert_eq!(pm.total_runs(), 2);
    }
    #[test]
    pub(super) fn test_pass_stats_display() {
        let mut stats = PassStats::new("test_pass");
        stats.record_run(5, 100);
        stats.record_run(3, 50);
        let s = format!("{}", stats);
        assert!(s.contains("test_pass"));
        assert!(s.contains("runs=2"));
        assert!(s.contains("changes=8"));
    }
    #[test]
    pub(super) fn test_pass_stats_avg() {
        let mut stats = PassStats::new("avg_test");
        stats.record_run(10, 0);
        stats.record_run(20, 0);
        assert!((stats.avg_changes() - 15.0).abs() < 0.001);
    }
    #[test]
    pub(super) fn test_pass_stats_empty_avg() {
        let stats = PassStats::new("empty");
        assert_eq!(stats.avg_changes(), 0.0);
    }
    #[test]
    pub(super) fn test_pass_dependency_display() {
        let dep = PassDependency::new("b", "a");
        assert_eq!(format!("{}", dep), "a -> b");
    }
    #[test]
    pub(super) fn test_substitute_var_in_return() {
        let mut expr = LcnfExpr::Return(LcnfArg::Var(vid(1)));
        substitute_var_in_expr(&mut expr, vid(1), vid(2));
        assert_eq!(expr, LcnfExpr::Return(LcnfArg::Var(vid(2))));
    }
    #[test]
    pub(super) fn test_substitute_var_in_tailcall() {
        let mut expr = LcnfExpr::TailCall(
            LcnfArg::Var(vid(1)),
            vec![LcnfArg::Var(vid(1)), LcnfArg::Lit(LcnfLit::Nat(0))],
        );
        substitute_var_in_expr(&mut expr, vid(1), vid(2));
        // Fail loudly if the expression shape changed instead of silently passing.
        match &expr {
            LcnfExpr::TailCall(f, args) => {
                assert_eq!(*f, LcnfArg::Var(vid(2)));
                assert_eq!(args[0], LcnfArg::Var(vid(2)));
                assert_eq!(args[1], LcnfArg::Lit(LcnfLit::Nat(0)));
            }
            other => panic!("expected TailCall, got {:?}", other),
        }
    }
    #[test]
    pub(super) fn test_substitute_var_in_case() {
        let mut expr = LcnfExpr::Case {
            scrutinee: vid(1),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![],
            default: Some(Box::new(LcnfExpr::Return(LcnfArg::Var(vid(1))))),
        };
        substitute_var_in_expr(&mut expr, vid(1), vid(2));
        // Fail loudly if the expression shape changed instead of silently passing.
        match &expr {
            LcnfExpr::Case {
                scrutinee, default, ..
            } => {
                assert_eq!(*scrutinee, vid(2));
                assert_eq!(
                    **default.as_ref().expect("expected Some/Ok value"),
                    LcnfExpr::Return(LcnfArg::Var(vid(2)))
                );
            }
            other => panic!("expected Case, got {:?}", other),
        }
    }
    #[test]
    pub(super) fn test_substitute_var_in_let() {
        let mut expr = mk_let(
            1,
            LcnfLetValue::FVar(vid(0)),
            LcnfExpr::Return(LcnfArg::Var(vid(0))),
        );
        substitute_var_in_expr(&mut expr, vid(0), vid(7));
        // Both the let-bound value and the body are rewritten.
        assert_eq!(
            expr,
            mk_let(
                1,
                LcnfLetValue::FVar(vid(7)),
                LcnfExpr::Return(LcnfArg::Var(vid(7))),
            )
        );
    }
    #[test]
    pub(super) fn test_run_all_passes() {
        let mut hints = PgoHints::new();
        hints.mark_hot("main");
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let mut decls = vec![mk_fun_decl("main", body)];
        run_all_passes(&mut decls, Some(&hints));
        run_all_passes(&mut decls, None);
        // An already-minimal body must come out of the pipeline unchanged.
        assert_eq!(decls.len(), 1);
        assert_eq!(
            decls[0].body,
            LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))
        );
    }
    #[test]
    pub(super) fn test_opt_pass_trait_constant_folding() {
        let mut pass = ConstantFoldingPass::new();
        assert_eq!(pass.name(), "constant_folding");
        assert!(pass.is_enabled());
        assert!(pass.dependencies().is_empty());
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let mut decls = vec![mk_fun_decl("f", body)];
        let changes = pass.run_pass(&mut decls);
        assert_eq!(changes, 0);
    }
    #[test]
    pub(super) fn test_opt_pass_trait_dce() {
        let mut pass = DeadCodeEliminationPass::new();
        assert_eq!(pass.name(), "dead_code_elimination");
    }
    #[test]
    pub(super) fn test_opt_pass_trait_cp() {
        let mut pass = CopyPropagationPass::new();
        assert_eq!(pass.name(), "copy_propagation");
    }
    #[test]
    pub(super) fn test_opt_pass_trait_beta() {
        let mut pass = BetaReductionPass::new();
        assert_eq!(pass.name(), "beta_reduction");
    }
    #[test]
    pub(super) fn test_opt_pass_trait_identity() {
        let mut pass = IdentityEliminationPass::new();
        assert_eq!(pass.name(), "identity_elimination");
    }
    #[test]
    pub(super) fn test_opt_pass_trait_unreachable() {
        let mut pass = UnreachableCodeEliminationPass::new();
        assert_eq!(pass.name(), "unreachable_code_elimination");
    }
    #[test]
    pub(super) fn test_pass_debug_impls() {
        let cf = ConstantFoldingPass::new();
        let dce = DeadCodeEliminationPass::new();
        let cp = CopyPropagationPass::new();
        let beta = BetaReductionPass::new();
        let id = IdentityEliminationPass::new();
        let sr = StrengthReductionPass::new();
        let uce = UnreachableCodeEliminationPass::new();
        assert!(format!("{:?}", cf).contains("ConstantFolding"));
        assert!(format!("{:?}", dce).contains("DeadCode"));
        assert!(format!("{:?}", cp).contains("CopyPropagation"));
        assert!(format!("{:?}", beta).contains("BetaReduction"));
        assert!(format!("{:?}", id).contains("Identity"));
        assert!(format!("{:?}", sr).contains("StrengthReduction"));
        assert!(format!("{:?}", uce).contains("Unreachable"));
    }
}
/// Unit tests for the `OP*` pass-infrastructure helpers: pass configuration
/// and registry, statistics, analysis cache, worklist, dominator tree,
/// liveness info, constant-folding helper and dependency graph.
///
/// Named in snake_case so the module does not trip the `non_snake_case` lint.
#[cfg(test)]
mod op_infra_tests {
    use super::*;
    #[test]
    pub(super) fn test_pass_config() {
        let config = OPPassConfig::new("test_pass", OPPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }
    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = OPPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }
    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = OPPassRegistry::new();
        reg.register(OPPassConfig::new("pass_a", OPPassPhase::Analysis));
        reg.register(OPPassConfig::new("pass_b", OPPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }
    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = OPAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        // One hit followed by one miss gives a 50% hit rate.
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        // Invalidation marks the entry invalid but keeps it in the cache.
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }
    #[test]
    pub(super) fn test_worklist() {
        let mut wl = OPWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        // Duplicate pushes are rejected.
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }
    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = OPDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }
    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = OPLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }
    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(OPConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(OPConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(OPConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            OPConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(OPConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }
    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = OPDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Map each node to its position so ordering constraints can be checked.
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}