Skip to main content

oxilean_codegen/opt_strength/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::*;
6use std::collections::HashMap;
7
8use super::types::{
9    InductionVariable, LinearFunction, OSAnalysisCache, OSConstantFoldingHelper, OSDepGraph,
10    OSDominatorTree, OSLivenessInfo, OSPassConfig, OSPassPhase, OSPassRegistry, OSPassStats,
11    OSWorklist, SRExtCache, SRExtConstFolder, SRExtDepGraph, SRExtDomTree, SRExtLiveness,
12    SRExtPassConfig, SRExtPassPhase, SRExtPassRegistry, SRExtPassStats, SRExtWorklist, SRX2Cache,
13    SRX2ConstFolder, SRX2DepGraph, SRX2DomTree, SRX2Liveness, SRX2PassConfig, SRX2PassPhase,
14    SRX2PassRegistry, SRX2PassStats, SRX2Worklist, StrengthConfig, StrengthReduceRule,
15    StrengthReductionPass, StrengthReport,
16};
17
18/// Collect all let-bindings of the form `y = a * iv + b`.
19pub(super) fn collect_linear_uses(
20    expr: &LcnfExpr,
21    iv: LcnfVarId,
22) -> HashMap<LcnfVarId, LinearFunction> {
23    let mut map = HashMap::new();
24    collect_linear_uses_inner(expr, iv, &mut map);
25    map
26}
27pub(super) fn collect_linear_uses_inner(
28    expr: &LcnfExpr,
29    iv: LcnfVarId,
30    map: &mut HashMap<LcnfVarId, LinearFunction>,
31) {
32    match expr {
33        LcnfExpr::Let {
34            id, value, body, ..
35        } => {
36            if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(fname)), args) = value {
37                if fname == "mul" && args.len() == 2 {
38                    let (a_coeff, has_iv) = extract_mul_iv(iv, args);
39                    if has_iv {
40                        map.insert(*id, LinearFunction::new(iv, a_coeff, 0));
41                    }
42                }
43            }
44            collect_linear_uses_inner(body, iv, map);
45        }
46        LcnfExpr::Case { alts, default, .. } => {
47            for alt in alts {
48                collect_linear_uses_inner(&alt.body, iv, map);
49            }
50            if let Some(d) = default {
51                collect_linear_uses_inner(d, iv, map);
52            }
53        }
54        LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(_, _) => {}
55    }
56}
57pub(super) fn extract_mul_iv(iv: LcnfVarId, args: &[LcnfArg]) -> (i64, bool) {
58    match (&args[0], &args[1]) {
59        (LcnfArg::Var(v), LcnfArg::Lit(LcnfLit::Nat(c))) if *v == iv => (*c as i64, true),
60        (LcnfArg::Lit(LcnfLit::Nat(c)), LcnfArg::Var(v)) if *v == iv => (*c as i64, true),
61        _ => (0, false),
62    }
63}
64pub(super) fn rewrite_linear_uses(
65    expr: &LcnfExpr,
66    _linears: &HashMap<LcnfVarId, LinearFunction>,
67    _iv: &InductionVariable,
68) -> LcnfExpr {
69    expr.clone()
70}
/// Report whether `n` equals `2^k` for some `k >= 0`.
///
/// Zero is not a power of two and yields `false`.
pub fn is_power_of_two(n: u64) -> bool {
    n.is_power_of_two()
}
75/// Integer log2 for a power-of-two value.
76pub fn log2(n: u64) -> u32 {
77    debug_assert!(is_power_of_two(n));
78    n.trailing_zeros()
79}
80/// Extract the variable argument from an App (the operand that is not a
81/// literal constant).
82pub(super) fn var_arg_of(func: &LcnfArg, args: &[LcnfArg]) -> LcnfArg {
83    if args.len() == 2 {
84        if const_arg(&args[0]).is_some() {
85            return args[1].clone();
86        }
87        if const_arg(&args[1]).is_some() {
88            return args[0].clone();
89        }
90        return args[0].clone();
91    }
92    if args.len() == 1 {
93        return args[0].clone();
94    }
95    func.clone()
96}
97/// If `arg` is a compile-time `Nat` literal, return its value.
98pub(super) fn const_arg(arg: &LcnfArg) -> Option<u64> {
99    match arg {
100        LcnfArg::Lit(LcnfLit::Nat(n)) => Some(*n),
101        _ => None,
102    }
103}
/// Decompose `c` into a sum of signed powers of two (non-adjacent form),
/// using at most `budget` terms.
///
/// Returns `Some(ops)` where each op is `(shift, sign)` with `sign` equal to
/// `1` or `-1`, meaning `x * c = sum_i (x << ops[i].0) * ops[i].1`. Ops are
/// listed by increasing shift; `c == 0` yields an empty list.
///
/// Returns `None` if more than `budget` terms would be needed.
///
/// The identity holds in exact integer arithmetic for every `u64` constant.
/// For constants close to `2^64` the largest shift can be `64`
/// (e.g. `u64::MAX` decomposes as `(x << 64) - x`).
pub fn decompose_mul(c: u64, budget: u32) -> Option<Vec<(u32, i64)>> {
    let mut ops: Vec<(u32, i64)> = Vec::new();
    // Work in a wider signed type: reinterpreting `c` as `i64` would turn
    // constants >= 2^63 into negative numbers, and rounding `i64::MAX` up by
    // one would overflow.
    let mut val = i128::from(c);
    let mut bit = 0u32;
    while val != 0 {
        if val & 1 != 0 {
            if val & 3 == 3 {
                // Low bits `..11`: emit -2^bit and round up, so the run of
                // ones collapses into a single higher term.
                ops.push((bit, -1));
                val += 1;
            } else {
                // Low bits `..01`: emit +2^bit.
                ops.push((bit, 1));
                val -= 1;
            }
        }
        val >>= 1;
        bit += 1;
    }
    if ops.len() <= budget as usize {
        Some(ops)
    } else {
        None
    }
}
136/// Run strength reduction on a vector of declarations with default config.
137pub fn optimize_strength(decls: &mut [LcnfFunDecl]) {
138    StrengthReductionPass::default().run(decls);
139}
// Unit tests for the free helper functions of this module and for the
// rewrites performed by `StrengthReductionPass`.
//
// NOTE(review): several tests inspect the rewritten expression inside
// `if let LcnfExpr::Let { .. } = &result { .. }` with no `else` branch, so
// their structural assertions are skipped silently if `reduce_expr` returns
// something that is not a `Let`.
140#[cfg(test)]
141mod tests {
142    use super::*;
    // Wrap a raw id in an `LcnfVarId`.
143    pub(super) fn make_var(id: u64) -> LcnfVarId {
144        LcnfVarId(id)
145    }
    // Build an application whose callee is the string literal `op`.
146    pub(super) fn app(op: &str, args: Vec<LcnfArg>) -> LcnfLetValue {
147        LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(op.into())), args)
148    }
    // Build a `Nat` literal argument.
149    pub(super) fn nat(n: u64) -> LcnfArg {
150        LcnfArg::Lit(LcnfLit::Nat(n))
151    }
    // Build a variable argument referring to variable `id`.
152    pub(super) fn var(id: u64) -> LcnfArg {
153        LcnfArg::Var(make_var(id))
154    }
    // Build `let v<id> : Nat = value; body`.
155    pub(super) fn let_expr(id: u64, value: LcnfLetValue, body: LcnfExpr) -> LcnfExpr {
156        LcnfExpr::Let {
157            id: make_var(id),
158            name: format!("v{}", id),
159            ty: LcnfType::Nat,
160            value,
161            body: Box::new(body),
162        }
163    }
    // Build an expression that returns variable `id`.
164    pub(super) fn ret(id: u64) -> LcnfExpr {
165        LcnfExpr::Return(var(id))
166    }
    // `is_power_of_two` accepts powers of two and rejects 0 and other values.
167    #[test]
168    pub(super) fn test_pow2_detection() {
169        assert!(is_power_of_two(1));
170        assert!(is_power_of_two(2));
171        assert!(is_power_of_two(4));
172        assert!(is_power_of_two(8));
173        assert!(is_power_of_two(16));
174        assert!(is_power_of_two(1024));
175        assert!(!is_power_of_two(0));
176        assert!(!is_power_of_two(3));
177        assert!(!is_power_of_two(6));
178        assert!(!is_power_of_two(7));
179    }
    // `log2` returns the exponent for power-of-two inputs.
180    #[test]
181    pub(super) fn test_log2() {
182        assert_eq!(log2(1), 0);
183        assert_eq!(log2(2), 1);
184        assert_eq!(log2(4), 2);
185        assert_eq!(log2(8), 3);
186        assert_eq!(log2(256), 8);
187    }
    // `mul(x, 4)` is expected to become `shl(_, 2)` and count one mul reduction.
188    #[test]
189    pub(super) fn test_mul_by_pow2_becomes_shl() {
190        let expr = let_expr(1, app("mul", vec![var(0), nat(4)]), ret(1));
191        let mut pass = StrengthReductionPass::default();
192        let result = pass.reduce_expr(&expr);
193        if let LcnfExpr::Let { value, .. } = &result {
194            match value {
195                LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(name)), args) => {
196                    assert_eq!(name, "shl");
197                    assert_eq!(args[1], nat(2));
198                }
199                _ => panic!("expected App(shl, ...)"),
200            }
201        }
202        assert_eq!(pass.report().mul_reduced, 1);
203    }
    // Same rewrite with the constant as the left operand: `mul(8, x)` -> `shl(_, 3)`.
204    #[test]
205    pub(super) fn test_mul_by_pow2_const_on_left() {
206        let expr = let_expr(1, app("mul", vec![nat(8), var(0)]), ret(1));
207        let mut pass = StrengthReductionPass::default();
208        let result = pass.reduce_expr(&expr);
209        if let LcnfExpr::Let { value, .. } = &result {
210            if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(name)), args) = value {
211                assert_eq!(name, "shl");
212                assert_eq!(args[1], nat(3));
213            } else {
214                panic!("expected shl");
215            }
216        }
217    }
    // `div(x, 16)` must be counted as one div reduction; the rewritten
    // expression itself is not inspected here.
218    #[test]
219    pub(super) fn test_div_by_pow2_becomes_lshr() {
220        let expr = let_expr(1, app("div", vec![var(0), nat(16)]), ret(1));
221        let mut pass = StrengthReductionPass::default();
222        pass.reduce_expr(&expr);
223        assert_eq!(pass.report().div_reduced, 1);
224    }
    // `mod(x, 8)` is expected to become `band(_, 7)`.
225    #[test]
226    pub(super) fn test_mod_by_pow2_becomes_band() {
227        let expr = let_expr(1, app("mod", vec![var(0), nat(8)]), ret(1));
228        let mut pass = StrengthReductionPass::default();
229        let result = pass.reduce_expr(&expr);
230        if let LcnfExpr::Let { value, .. } = &result {
231            if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(name)), args) = value {
232                assert_eq!(name, "band");
233                assert_eq!(args[1], nat(7));
234            } else {
235                panic!("expected band");
236            }
237        }
238    }
    // `pow(x, 2)` is expected to become `mul(x, x)` and count one pow reduction.
239    #[test]
240    pub(super) fn test_pow2_const_becomes_mul() {
241        let expr = let_expr(1, app("pow", vec![var(0), nat(2)]), ret(1));
242        let mut pass = StrengthReductionPass::default();
243        let result = pass.reduce_expr(&expr);
244        if let LcnfExpr::Let { value, .. } = &result {
245            if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(name)), args) = value {
246                assert_eq!(name, "mul");
247                assert_eq!(args[0], var(0));
248                assert_eq!(args[1], var(0));
249            } else {
250                panic!("expected mul(x,x)");
251            }
252        }
253        assert_eq!(pass.report().pow_reduced, 1);
254    }
    // `pow(x, 3)` is expected to expand into two nested `mul` bindings
    // (the outer one for the square, the inner one for the cube).
255    #[test]
256    pub(super) fn test_pow3_const_introduces_prefix() {
257        let expr = let_expr(1, app("pow", vec![var(0), nat(3)]), ret(1));
258        let mut pass = StrengthReductionPass::default();
259        let result = pass.reduce_expr(&expr);
260        if let LcnfExpr::Let {
261            value: v0, body, ..
262        } = &result
263        {
264            if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(n0)), _) = v0 {
265                assert_eq!(n0, "mul");
266            } else {
267                panic!("expected mul for square");
268            }
269            if let LcnfExpr::Let { value: v1, .. } = body.as_ref() {
270                if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(n1)), _) = v1 {
271                    assert_eq!(n1, "mul");
272                } else {
273                    panic!("expected mul for cube");
274                }
275            }
276        }
277        assert_eq!(pass.report().pow_reduced, 1);
278    }
    // `sub(0, x)` is expected to become a `neg` application and count one
    // neg reduction.
279    #[test]
280    pub(super) fn test_neg_to_sub() {
281        let expr = let_expr(1, app("sub", vec![nat(0), var(0)]), ret(1));
282        let mut pass = StrengthReductionPass::default();
283        let result = pass.reduce_expr(&expr);
284        if let LcnfExpr::Let { value, .. } = &result {
285            if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(name)), _) = value {
286                assert_eq!(name, "neg");
287            } else {
288                panic!("expected neg");
289            }
290        }
291        assert_eq!(pass.report().neg_reduced, 1);
292    }
    // `add(x, 1)` is expected to become an `incr` application and count one
    // inc reduction.
293    #[test]
294    pub(super) fn test_add1_becomes_incr() {
295        let expr = let_expr(1, app("add", vec![var(0), nat(1)]), ret(1));
296        let mut pass = StrengthReductionPass::default();
297        let result = pass.reduce_expr(&expr);
298        if let LcnfExpr::Let { value, .. } = &result {
299            if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(name)), _) = value {
300                assert_eq!(name, "incr");
301            } else {
302                panic!("expected incr");
303            }
304        }
305        assert_eq!(pass.report().inc_reduced, 1);
306    }
    // `mul(x, 3)` (not a power of two) must still count as one mul reduction;
    // only the counter is checked.
307    #[test]
308    pub(super) fn test_mul_by_non_pow2_constant() {
309        let expr = let_expr(1, app("mul", vec![var(0), nat(3)]), ret(1));
310        let mut pass = StrengthReductionPass::default();
311        pass.reduce_expr(&expr);
312        assert_eq!(pass.report().mul_reduced, 1);
313    }
    // `div(x, 7)` is expected to become a `magic_div` application and count
    // one div reduction.
314    #[test]
315    pub(super) fn test_div_by_constant_magic() {
316        let expr = let_expr(1, app("div", vec![var(0), nat(7)]), ret(1));
317        let mut pass = StrengthReductionPass::default();
318        let result = pass.reduce_expr(&expr);
319        if let LcnfExpr::Let { value, .. } = &result {
320            if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(name)), _) = value {
321                assert_eq!(name, "magic_div");
322            } else {
323                panic!("expected magic_div");
324            }
325        }
326        assert_eq!(pass.report().div_reduced, 1);
327    }
    // With `optimize_div` switched off, a `div(x, 7)` application must keep
    // the name `div`.
    // NOTE(review): nothing is asserted if the result is not a `Let` holding
    // a string-named `App`, so this test can pass vacuously.
328    #[test]
329    pub(super) fn test_div_by_constant_disabled() {
330        let expr = let_expr(1, app("div", vec![var(0), nat(7)]), ret(1));
331        let mut pass = StrengthReductionPass::new(StrengthConfig {
332            optimize_div: false,
333            ..Default::default()
334        });
335        let result = pass.reduce_expr(&expr);
336        if let LcnfExpr::Let { value, .. } = &result {
337            if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(name)), _) = value {
338                assert_eq!(name, "div");
339            }
340        }
341    }
    // An application of an unrecognised operator must not bump the mul counter.
342    #[test]
343    pub(super) fn test_no_reduction_for_unknown_op() {
344        let expr = let_expr(1, app("custom_op", vec![var(0), nat(4)]), ret(1));
345        let mut pass = StrengthReductionPass::default();
346        pass.reduce_expr(&expr);
347        assert_eq!(pass.report().mul_reduced, 0);
348    }
    // Smoke test: running the pass on an empty declaration list must not panic.
349    #[test]
350    pub(super) fn test_run_on_empty_decls() {
351        let mut decls: Vec<LcnfFunDecl> = vec![];
352        StrengthReductionPass::default().run(&mut decls);
353    }
    // A power of two decomposes into a single term with the matching shift.
354    #[test]
355    pub(super) fn test_decompose_mul_pow2() {
356        let ops = decompose_mul(4, 3).expect("ops decomposition should succeed");
357        assert_eq!(ops.len(), 1);
358        assert_eq!(ops[0].0, 2);
359    }
    // 3 decomposes within a budget of three terms.
360    #[test]
361    pub(super) fn test_decompose_mul_3() {
362        let ops = decompose_mul(3, 3);
363        assert!(ops.is_some());
364        assert!(ops.expect("value should be Some/Ok").len() <= 3);
365    }
    // A constant with many isolated set bits exceeds a budget of two terms.
366    #[test]
367    pub(super) fn test_decompose_mul_over_budget() {
368        let large = 0b10101010101010u64;
369        let ops = decompose_mul(large, 2);
370        assert!(ops.is_none());
371    }
    // Zero decomposes into an empty list of terms.
372    #[test]
373    pub(super) fn test_decompose_mul_zero() {
374        let ops = decompose_mul(0, 3).expect("ops decomposition should succeed");
375        assert!(ops.is_empty());
376    }
    // A function whose body is a plain `Return` yields no induction variables.
377    #[test]
378    pub(super) fn test_detect_induction_vars_no_tail_call() {
379        let decl = LcnfFunDecl {
380            name: "f".into(),
381            original_name: None,
382            params: vec![LcnfParam {
383                id: make_var(0),
384                name: "n".into(),
385                ty: LcnfType::Nat,
386                erased: false,
387                borrowed: false,
388            }],
389            ret_type: LcnfType::Nat,
390            body: LcnfExpr::Return(var(0)),
391            is_recursive: false,
392            is_lifted: false,
393            inline_cost: 1,
394        };
395        let pass = StrengthReductionPass::default();
396        let ivs = pass.detect_induction_vars(&decl);
397        assert!(ivs.is_empty());
398    }
    // A recursive function whose body is a tail call passing its single
    // parameter yields exactly one induction variable: that parameter.
399    #[test]
400    pub(super) fn test_detect_induction_vars_with_tail_call() {
401        let decl = LcnfFunDecl {
402            name: "loop".into(),
403            original_name: None,
404            params: vec![LcnfParam {
405                id: make_var(0),
406                name: "i".into(),
407                ty: LcnfType::Nat,
408                erased: false,
409                borrowed: false,
410            }],
411            ret_type: LcnfType::Nat,
412            body: LcnfExpr::TailCall(LcnfArg::Lit(LcnfLit::Str("loop".into())), vec![var(0)]),
413            is_recursive: true,
414            is_lifted: false,
415            inline_cost: 5,
416        };
417        let pass = StrengthReductionPass::default();
418        let ivs = pass.detect_induction_vars(&decl);
419        assert_eq!(ivs.len(), 1);
420        assert_eq!(ivs[0].var, make_var(0));
421    }
    // The `Display` output of the default config must contain `max_shift=3`.
422    #[test]
423    pub(super) fn test_strength_config_display() {
424        let c = StrengthConfig::default();
425        let s = format!("{}", c);
426        assert!(s.contains("max_shift=3"));
427    }
    // The `Display` output of a report must contain the mul/div/pow counters.
428    #[test]
429    pub(super) fn test_strength_report_display() {
430        let r = StrengthReport {
431            mul_reduced: 2,
432            div_reduced: 1,
433            pow_reduced: 3,
434            iv_reductions: 0,
435            inc_reduced: 1,
436            neg_reduced: 0,
437        };
438        let s = format!("{}", r);
439        assert!(s.contains("mul=2"));
440        assert!(s.contains("div=1"));
441        assert!(s.contains("pow=3"));
442    }
    // `LinearFunction::new(v, 3, 5).eval(2)` is 11 (3 * 2 + 5); only the
    // function with coefficient 1 and offset 0 is reported as the identity.
443    #[test]
444    pub(super) fn test_linear_function() {
445        let lf = LinearFunction::new(make_var(0), 3, 5);
446        assert_eq!(lf.eval(2), 11);
447        assert!(!lf.is_identity());
448        let id = LinearFunction::new(make_var(0), 1, 0);
449        assert!(id.is_identity());
450    }
    // `Display` for each rule prints the variant name (with payload, if any).
451    #[test]
452    pub(super) fn test_strength_rule_display() {
453        assert_eq!(
454            format!("{}", StrengthReduceRule::MulByPow2(3)),
455            "MulByPow2(3)"
456        );
457        assert_eq!(
458            format!("{}", StrengthReduceRule::DivByPow2(2)),
459            "DivByPow2(2)"
460        );
461        assert_eq!(format!("{}", StrengthReduceRule::Pow2Const), "Pow2Const");
462        assert_eq!(format!("{}", StrengthReduceRule::NegToSub), "NegToSub");
463        assert_eq!(
464            format!("{}", StrengthReduceRule::AddSubToInc),
465            "AddSubToInc"
466        );
467    }
    // Smoke test: the convenience wrapper accepts an empty declaration list.
468    #[test]
469    pub(super) fn test_optimize_strength_convenience() {
470        let mut decls: Vec<LcnfFunDecl> = vec![];
471        optimize_strength(&mut decls);
472    }
473}
// Tests for the `OS*` pass-infrastructure helper types imported from
// `super::types` (config, stats, registry, cache, worklist, dominator tree,
// liveness, constant folding and dependency graph).
474#[cfg(test)]
475mod OS_infra_tests {
476    use super::*;
    // A freshly built transformation-phase config is enabled, its phase is
    // modifying, and the phase name is "transformation".
477    #[test]
478    pub(super) fn test_pass_config() {
479        let config = OSPassConfig::new("test_pass", OSPassPhase::Transformation);
480        assert!(config.enabled);
481        assert!(config.phase.is_modifying());
482        assert_eq!(config.phase.name(), "transformation");
483    }
    // Two recorded runs (first arguments 10 and 20) give two total runs, an
    // average of 15 changes per run, a success rate of 1.0 and a summary
    // containing "Runs: 2/2".
484    #[test]
485    pub(super) fn test_pass_stats() {
486        let mut stats = OSPassStats::new();
487        stats.record_run(10, 100, 3);
488        stats.record_run(20, 200, 5);
489        assert_eq!(stats.total_runs, 2);
490        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
491        assert!((stats.success_rate() - 1.0).abs() < 0.01);
492        let s = stats.format_summary();
493        assert!(s.contains("Runs: 2/2"));
494    }
    // The registry counts every registered pass but only enabled ones in
    // `enabled_count`; stats updated by name can be read back.
495    #[test]
496    pub(super) fn test_pass_registry() {
497        let mut reg = OSPassRegistry::new();
498        reg.register(OSPassConfig::new("pass_a", OSPassPhase::Analysis));
499        reg.register(OSPassConfig::new("pass_b", OSPassPhase::Transformation).disabled());
500        assert_eq!(reg.total_passes(), 2);
501        assert_eq!(reg.enabled_count(), 1);
502        reg.update_stats("pass_a", 5, 50, 2);
503        let stats = reg.get_stats("pass_a").expect("stats should exist");
504        assert_eq!(stats.total_changes, 5);
505    }
    // One successful and one failed lookup give a hit rate of 0.5; after
    // `invalidate` the entry is still stored (size 1) but flagged not valid.
506    #[test]
507    pub(super) fn test_analysis_cache() {
508        let mut cache = OSAnalysisCache::new(10);
509        cache.insert("key1".to_string(), vec![1, 2, 3]);
510        assert!(cache.get("key1").is_some());
511        assert!(cache.get("key2").is_none());
512        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
513        cache.invalidate("key1");
514        assert!(!cache.entries["key1"].valid);
515        assert_eq!(cache.size(), 1);
516    }
    // Pushing a duplicate returns `false` and does not grow the worklist;
    // `pop` returns the item pushed first and removes it from membership.
517    #[test]
518    pub(super) fn test_worklist() {
519        let mut wl = OSWorklist::new();
520        assert!(wl.push(1));
521        assert!(wl.push(2));
522        assert!(!wl.push(1));
523        assert_eq!(wl.len(), 2);
524        assert_eq!(wl.pop(), Some(1));
525        assert!(!wl.contains(1));
526        assert!(wl.contains(2));
527    }
    // With idom(1)=0, idom(2)=0, idom(3)=1: nodes 0 and 1 dominate 3, node 2
    // does not, and a node dominates itself.
528    #[test]
529    pub(super) fn test_dominator_tree() {
530        let mut dt = OSDominatorTree::new(5);
531        dt.set_idom(1, 0);
532        dt.set_idom(2, 0);
533        dt.set_idom(3, 1);
534        assert!(dt.dominates(0, 3));
535        assert!(dt.dominates(1, 3));
536        assert!(!dt.dominates(2, 3));
537        assert!(dt.dominates(3, 3));
538    }
    // `add_def` / `add_use` record the value in the `defs` / `uses` entry
    // selected by their first argument.
539    #[test]
540    pub(super) fn test_liveness() {
541        let mut liveness = OSLivenessInfo::new(3);
542        liveness.add_def(0, 1);
543        liveness.add_use(1, 1);
544        assert!(liveness.defs[0].contains(&1));
545        assert!(liveness.uses[1].contains(&1));
546    }
    // Folding helpers: add and divide return `Option` (division by zero is
    // `None`); bit-and and bit-not return plain values.
547    #[test]
548    pub(super) fn test_constant_folding() {
549        assert_eq!(OSConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
550        assert_eq!(OSConstantFoldingHelper::fold_div_i64(10, 0), None);
551        assert_eq!(OSConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
552        assert_eq!(
553            OSConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
554            0b1000
555        );
556        assert_eq!(OSConstantFoldingHelper::fold_bitnot_i64(0), -1);
557    }
    // After add_dep(1, 2), add_dep(2, 3), add_dep(1, 3) the graph is acyclic,
    // `dependencies_of(2)` is `[1]`, and the topological order places
    // 1 before 2 before 3.
558    #[test]
559    pub(super) fn test_dep_graph() {
560        let mut g = OSDepGraph::new();
561        g.add_dep(1, 2);
562        g.add_dep(2, 3);
563        g.add_dep(1, 3);
564        assert_eq!(g.dependencies_of(2), vec![1]);
565        let topo = g.topological_sort();
566        assert_eq!(topo.len(), 3);
567        assert!(!g.has_cycle());
        // Map each node to its position in the topological order.
568        let pos: std::collections::HashMap<u32, usize> =
569            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
570        assert!(pos[&1] < pos[&2]);
571        assert!(pos[&1] < pos[&3]);
572        assert!(pos[&2] < pos[&3]);
573    }
574}
// Tests for the `SRExt*` pass-infrastructure helper types imported from
// `super::types`.
575#[cfg(test)]
576mod srext_pass_tests {
577    use super::*;
    // Phases are ordered Early(0) < Middle(1) < Late(2) < Finalize(3);
    // `Early` is early and not late.
578    #[test]
579    pub(super) fn test_srext_phase_order() {
580        assert_eq!(SRExtPassPhase::Early.order(), 0);
581        assert_eq!(SRExtPassPhase::Middle.order(), 1);
582        assert_eq!(SRExtPassPhase::Late.order(), 2);
583        assert_eq!(SRExtPassPhase::Finalize.order(), 3);
584        assert!(SRExtPassPhase::Early.is_early());
585        assert!(!SRExtPassPhase::Early.is_late());
586    }
    // Builder methods set the name, iteration limit and debug flag; a config
    // starts enabled and `disabled()` yields a config that is not enabled.
587    #[test]
588    pub(super) fn test_srext_config_builder() {
589        let c = SRExtPassConfig::new("p")
590            .with_phase(SRExtPassPhase::Late)
591            .with_max_iter(50)
592            .with_debug(1);
593        assert_eq!(c.name, "p");
594        assert_eq!(c.max_iterations, 50);
595        assert!(c.is_debug_enabled());
596        assert!(c.enabled);
597        let c2 = c.disabled();
598        assert!(!c2.enabled);
599    }
    // Two visits, one modification and one iteration give an efficiency of 0.5
    // and set the `changed` flag.
600    #[test]
601    pub(super) fn test_srext_stats() {
602        let mut s = SRExtPassStats::new();
603        s.visit();
604        s.visit();
605        s.modify();
606        s.iterate();
607        assert_eq!(s.nodes_visited, 2);
608        assert_eq!(s.nodes_modified, 1);
609        assert!(s.changed);
610        assert_eq!(s.iterations, 1);
611        let e = s.efficiency();
612        assert!((e - 0.5).abs() < 1e-9);
613    }
    // The registry holds both passes, reports one as enabled and one in the
    // `Early` phase.
614    #[test]
615    pub(super) fn test_srext_registry() {
616        let mut r = SRExtPassRegistry::new();
617        r.register(SRExtPassConfig::new("a").with_phase(SRExtPassPhase::Early));
618        r.register(SRExtPassConfig::new("b").disabled());
619        assert_eq!(r.len(), 2);
620        assert_eq!(r.enabled_passes().len(), 1);
621        assert_eq!(r.passes_in_phase(&SRExtPassPhase::Early).len(), 1);
622    }
    // A lookup before `put` misses; after `put` the stored bytes are returned,
    // the hit rate is positive and one entry is live.
623    #[test]
624    pub(super) fn test_srext_cache() {
625        let mut c = SRExtCache::new(4);
626        assert!(c.get(99).is_none());
627        c.put(99, vec![1, 2, 3]);
628        let v = c.get(99).expect("v should be present in map");
629        assert_eq!(v, &[1u8, 2, 3]);
630        assert!(c.hit_rate() > 0.0);
631        assert_eq!(c.live_count(), 1);
632    }
    // Pushing 5 twice leaves two items in the worklist; a popped item is no
    // longer reported by `contains`.
633    #[test]
634    pub(super) fn test_srext_worklist() {
635        let mut w = SRExtWorklist::new(10);
636        w.push(5);
637        w.push(3);
638        w.push(5);
639        assert_eq!(w.len(), 2);
640        assert!(w.contains(5));
641        let first = w.pop().expect("first should be available to pop");
642        assert!(!w.contains(first));
643    }
    // With idom(1)=0, idom(2)=0, idom(3)=1, idom(4)=1: 0 dominates 3,
    // 1 dominates 4, 2 does not dominate 3, and node 3 has depth 2.
644    #[test]
645    pub(super) fn test_srext_dom_tree() {
646        let mut dt = SRExtDomTree::new(5);
647        dt.set_idom(1, 0);
648        dt.set_idom(2, 0);
649        dt.set_idom(3, 1);
650        dt.set_idom(4, 1);
651        assert!(dt.dominates(0, 3));
652        assert!(dt.dominates(1, 4));
653        assert!(!dt.dominates(2, 3));
654        assert_eq!(dt.depth_of(3), 2);
655    }
    // After add_def(0, 1) and add_use(1, 1), the def query holds for (0, 1)
    // but not (1, 1), and the use query holds for (1, 1).
656    #[test]
657    pub(super) fn test_srext_liveness() {
658        let mut lv = SRExtLiveness::new(3);
659        lv.add_def(0, 1);
660        lv.add_use(1, 1);
661        assert!(lv.var_is_def_in_block(0, 1));
662        assert!(lv.var_is_used_in_block(1, 1));
663        assert!(!lv.var_is_def_in_block(1, 1));
664    }
    // After one add, one division by zero (`None`), one multiply and one
    // bit-and, the folder reports `fold_count() == 3` and
    // `failure_count() == 1`.
665    #[test]
666    pub(super) fn test_srext_const_folder() {
667        let mut cf = SRExtConstFolder::new();
668        assert_eq!(cf.add_i64(3, 4), Some(7));
669        assert_eq!(cf.div_i64(10, 0), None);
670        assert_eq!(cf.mul_i64(6, 7), Some(42));
671        assert_eq!(cf.and_i64(0b1100, 0b1010), 0b1000);
672        assert_eq!(cf.fold_count(), 3);
673        assert_eq!(cf.failure_count(), 1);
674    }
    // A four-node chain 0 -> 1 -> 2 -> 3 is acyclic, sorts topologically as
    // [0, 1, 2, 3], reaches four nodes from 0 and has four SCCs.
675    #[test]
676    pub(super) fn test_srext_dep_graph() {
677        let mut g = SRExtDepGraph::new(4);
678        g.add_edge(0, 1);
679        g.add_edge(1, 2);
680        g.add_edge(2, 3);
681        assert!(!g.has_cycle());
682        assert_eq!(g.topo_sort(), Some(vec![0, 1, 2, 3]));
683        assert_eq!(g.reachable(0).len(), 4);
684        let sccs = g.scc();
685        assert_eq!(sccs.len(), 4);
686    }
687}
// Tests for the `SRX2*` pass-infrastructure helper types imported from
// `super::types`; they make the same assertions as the `SRExt*` tests.
688#[cfg(test)]
689mod srx2_pass_tests {
690    use super::*;
    // Phases are ordered Early(0) < Middle(1) < Late(2) < Finalize(3);
    // `Early` is early and not late.
691    #[test]
692    pub(super) fn test_srx2_phase_order() {
693        assert_eq!(SRX2PassPhase::Early.order(), 0);
694        assert_eq!(SRX2PassPhase::Middle.order(), 1);
695        assert_eq!(SRX2PassPhase::Late.order(), 2);
696        assert_eq!(SRX2PassPhase::Finalize.order(), 3);
697        assert!(SRX2PassPhase::Early.is_early());
698        assert!(!SRX2PassPhase::Early.is_late());
699    }
    // Builder methods set the name, iteration limit and debug flag; a config
    // starts enabled and `disabled()` yields a config that is not enabled.
700    #[test]
701    pub(super) fn test_srx2_config_builder() {
702        let c = SRX2PassConfig::new("p")
703            .with_phase(SRX2PassPhase::Late)
704            .with_max_iter(50)
705            .with_debug(1);
706        assert_eq!(c.name, "p");
707        assert_eq!(c.max_iterations, 50);
708        assert!(c.is_debug_enabled());
709        assert!(c.enabled);
710        let c2 = c.disabled();
711        assert!(!c2.enabled);
712    }
    // Two visits, one modification and one iteration give an efficiency of 0.5
    // and set the `changed` flag.
713    #[test]
714    pub(super) fn test_srx2_stats() {
715        let mut s = SRX2PassStats::new();
716        s.visit();
717        s.visit();
718        s.modify();
719        s.iterate();
720        assert_eq!(s.nodes_visited, 2);
721        assert_eq!(s.nodes_modified, 1);
722        assert!(s.changed);
723        assert_eq!(s.iterations, 1);
724        let e = s.efficiency();
725        assert!((e - 0.5).abs() < 1e-9);
726    }
    // The registry holds both passes, reports one as enabled and one in the
    // `Early` phase.
727    #[test]
728    pub(super) fn test_srx2_registry() {
729        let mut r = SRX2PassRegistry::new();
730        r.register(SRX2PassConfig::new("a").with_phase(SRX2PassPhase::Early));
731        r.register(SRX2PassConfig::new("b").disabled());
732        assert_eq!(r.len(), 2);
733        assert_eq!(r.enabled_passes().len(), 1);
734        assert_eq!(r.passes_in_phase(&SRX2PassPhase::Early).len(), 1);
735    }
    // A lookup before `put` misses; after `put` the stored bytes are returned,
    // the hit rate is positive and one entry is live.
736    #[test]
737    pub(super) fn test_srx2_cache() {
738        let mut c = SRX2Cache::new(4);
739        assert!(c.get(99).is_none());
740        c.put(99, vec![1, 2, 3]);
741        let v = c.get(99).expect("v should be present in map");
742        assert_eq!(v, &[1u8, 2, 3]);
743        assert!(c.hit_rate() > 0.0);
744        assert_eq!(c.live_count(), 1);
745    }
    // Pushing 5 twice leaves two items in the worklist; a popped item is no
    // longer reported by `contains`.
746    #[test]
747    pub(super) fn test_srx2_worklist() {
748        let mut w = SRX2Worklist::new(10);
749        w.push(5);
750        w.push(3);
751        w.push(5);
752        assert_eq!(w.len(), 2);
753        assert!(w.contains(5));
754        let first = w.pop().expect("first should be available to pop");
755        assert!(!w.contains(first));
756    }
    // With idom(1)=0, idom(2)=0, idom(3)=1, idom(4)=1: 0 dominates 3,
    // 1 dominates 4, 2 does not dominate 3, and node 3 has depth 2.
757    #[test]
758    pub(super) fn test_srx2_dom_tree() {
759        let mut dt = SRX2DomTree::new(5);
760        dt.set_idom(1, 0);
761        dt.set_idom(2, 0);
762        dt.set_idom(3, 1);
763        dt.set_idom(4, 1);
764        assert!(dt.dominates(0, 3));
765        assert!(dt.dominates(1, 4));
766        assert!(!dt.dominates(2, 3));
767        assert_eq!(dt.depth_of(3), 2);
768    }
    // After add_def(0, 1) and add_use(1, 1), the def query holds for (0, 1)
    // but not (1, 1), and the use query holds for (1, 1).
769    #[test]
770    pub(super) fn test_srx2_liveness() {
771        let mut lv = SRX2Liveness::new(3);
772        lv.add_def(0, 1);
773        lv.add_use(1, 1);
774        assert!(lv.var_is_def_in_block(0, 1));
775        assert!(lv.var_is_used_in_block(1, 1));
776        assert!(!lv.var_is_def_in_block(1, 1));
777    }
    // After one add, one division by zero (`None`), one multiply and one
    // bit-and, the folder reports `fold_count() == 3` and
    // `failure_count() == 1`.
778    #[test]
779    pub(super) fn test_srx2_const_folder() {
780        let mut cf = SRX2ConstFolder::new();
781        assert_eq!(cf.add_i64(3, 4), Some(7));
782        assert_eq!(cf.div_i64(10, 0), None);
783        assert_eq!(cf.mul_i64(6, 7), Some(42));
784        assert_eq!(cf.and_i64(0b1100, 0b1010), 0b1000);
785        assert_eq!(cf.fold_count(), 3);
786        assert_eq!(cf.failure_count(), 1);
787    }
    // A four-node chain 0 -> 1 -> 2 -> 3 is acyclic, sorts topologically as
    // [0, 1, 2, 3], reaches four nodes from 0 and has four SCCs.
788    #[test]
789    pub(super) fn test_srx2_dep_graph() {
790        let mut g = SRX2DepGraph::new(4);
791        g.add_edge(0, 1);
792        g.add_edge(1, 2);
793        g.add_edge(2, 3);
794        assert!(!g.has_cycle());
795        assert_eq!(g.topo_sort(), Some(vec![0, 1, 2, 3]));
796        assert_eq!(g.reachable(0).len(), 4);
797        let sccs = g.scc();
798        assert_eq!(sccs.len(), 4);
799    }
800}