Skip to main content

oxilean_codegen/opt_loop_unroll/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::{LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfVarId};
6use std::collections::HashMap;
7
8use super::types::{
9    LUAnalysisCache, LUConstantFoldingHelper, LUDepGraph, LUDominatorTree, LULivenessInfo,
10    LUPassConfig, LUPassPhase, LUPassRegistry, LUPassStats, LUWorklist, LoopInfo, LoopUnrollPass,
11    UnrollCandidate, UnrollConfig, UnrollFactor, UnrollReport,
12};
13
14/// Count the number of distinct variable references in an expression.
15pub fn count_var_refs(expr: &LcnfExpr, target: LcnfVarId) -> usize {
16    match expr {
17        LcnfExpr::Let {
18            id, value, body, ..
19        } => {
20            let in_value = count_var_refs_in_value(value, target);
21            let in_body = if *id == target {
22                0
23            } else {
24                count_var_refs(body, target)
25            };
26            in_value + in_body
27        }
28        LcnfExpr::Case {
29            scrutinee,
30            alts,
31            default,
32            ..
33        } => {
34            let scrutinee_count = if *scrutinee == target { 1 } else { 0 };
35            let alt_count: usize = alts
36                .iter()
37                .filter(|a| a.params.iter().all(|p| p.id != target))
38                .map(|a| count_var_refs(&a.body, target))
39                .sum();
40            let default_count = default
41                .as_ref()
42                .map(|d| count_var_refs(d, target))
43                .unwrap_or(0);
44            scrutinee_count + alt_count + default_count
45        }
46        LcnfExpr::Return(arg) | LcnfExpr::TailCall(arg, _) => {
47            if let crate::lcnf::LcnfArg::Var(id) = arg {
48                if *id == target {
49                    1
50                } else {
51                    0
52                }
53            } else {
54                0
55            }
56        }
57        LcnfExpr::Unreachable => 0,
58    }
59}
60pub(super) fn count_var_refs_in_value(value: &LcnfLetValue, target: LcnfVarId) -> usize {
61    let count_arg = |a: &crate::lcnf::LcnfArg| {
62        matches!(a, crate ::lcnf::LcnfArg::Var(id) if * id == target) as usize
63    };
64    match value {
65        LcnfLetValue::App(f, args) => count_arg(f) + args.iter().map(count_arg).sum::<usize>(),
66        LcnfLetValue::Proj(_, _, v) => {
67            if *v == target {
68                1
69            } else {
70                0
71            }
72        }
73        LcnfLetValue::Ctor(_, _, args) => args.iter().map(count_arg).sum(),
74        LcnfLetValue::FVar(id) => {
75            if *id == target {
76                1
77            } else {
78                0
79            }
80        }
81        LcnfLetValue::Reset(v) => {
82            if *v == target {
83                1
84            } else {
85                0
86            }
87        }
88        LcnfLetValue::Reuse(slot, _, _, args) => {
89            let s = if *slot == target { 1 } else { 0 };
90            s + args.iter().map(count_arg).sum::<usize>()
91        }
92        LcnfLetValue::Lit(_) | LcnfLetValue::Erased => 0,
93    }
94}
95/// Estimate the abstract instruction count of an LCNF expression.
96pub fn estimate_expr_size(expr: &LcnfExpr) -> u64 {
97    match expr {
98        LcnfExpr::Let { body, .. } => 1 + estimate_expr_size(body),
99        LcnfExpr::Case { alts, default, .. } => {
100            let alt_sizes: u64 = alts.iter().map(|a| estimate_expr_size(&a.body)).sum();
101            let def_size = default.as_ref().map(|d| estimate_expr_size(d)).unwrap_or(0);
102            1 + alt_sizes + def_size
103        }
104        LcnfExpr::Return(_) | LcnfExpr::TailCall(_, _) | LcnfExpr::Unreachable => 1,
105    }
106}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lcnf::{
        LcnfAlt, LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfParam, LcnfType,
        LcnfVarId,
    };

    /// Build `let v{id} : Nat := n; body`.
    fn make_nat_lit(id: u64, n: u64, body: LcnfExpr) -> LcnfExpr {
        LcnfExpr::Let {
            id: LcnfVarId(id),
            name: format!("v{}", id),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(n)),
            body: Box::new(body),
        }
    }

    /// Build `return n` for the Nat literal `n`.
    fn make_return_nat(n: u64) -> LcnfExpr {
        LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(n)))
    }

    /// Build `return v{id}`.
    fn make_return_var(id: u64) -> LcnfExpr {
        LcnfExpr::Return(LcnfArg::Var(LcnfVarId(id)))
    }

    /// Build a non-erased, non-borrowed Nat parameter.
    fn make_nat_param(id: u64, name: &str) -> LcnfParam {
        LcnfParam {
            id: LcnfVarId(id),
            name: name.to_string(),
            ty: LcnfType::Nat,
            erased: false,
            borrowed: false,
        }
    }

    /// Build a parameterless, non-recursive declaration returning Nat.
    fn make_decl(name: &str, body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 0,
        }
    }

    #[test]
    fn test_unroll_factor_full_has_no_numeric_factor() {
        assert_eq!(UnrollFactor::Full.factor(), None);
    }
    #[test]
    fn test_unroll_factor_partial_returns_factor() {
        assert_eq!(UnrollFactor::Partial(4).factor(), Some(4));
    }
    #[test]
    fn test_unroll_factor_jamming_has_no_numeric_factor() {
        assert_eq!(UnrollFactor::Jamming.factor(), None);
    }
    #[test]
    fn test_unroll_factor_vectorizable_returns_factor() {
        assert_eq!(UnrollFactor::Vectorizable(8).factor(), Some(8));
    }
    #[test]
    fn test_unroll_factor_names() {
        assert_eq!(UnrollFactor::Full.name(), "full");
        assert_eq!(UnrollFactor::Partial(2).name(), "partial");
        assert_eq!(UnrollFactor::Jamming.name(), "jamming");
        assert_eq!(UnrollFactor::Vectorizable(4).name(), "vectorizable");
    }
    #[test]
    fn test_loop_info_trip_count_basic() {
        let info = LoopInfo::new(LcnfVarId(0), 0, 8, 1, vec![]);
        assert_eq!(info.trip_count, Some(8));
    }
    #[test]
    fn test_loop_info_trip_count_step2() {
        let info = LoopInfo::new(LcnfVarId(0), 0, 8, 2, vec![]);
        assert_eq!(info.trip_count, Some(4));
    }
    #[test]
    fn test_loop_info_trip_count_non_zero_start() {
        let info = LoopInfo::new(LcnfVarId(0), 3, 15, 3, vec![]);
        assert_eq!(info.trip_count, Some(4));
    }
    #[test]
    fn test_loop_info_is_counted_when_trip_known() {
        let info = LoopInfo::new(LcnfVarId(0), 0, 4, 1, vec![]);
        assert!(info.is_counted);
    }
    #[test]
    fn test_loop_info_priority_score_innermost_bonus() {
        let mut info = LoopInfo::new(LcnfVarId(0), 0, 8, 1, vec![]);
        info.is_innermost = true;
        let score_inner = info.priority_score();
        let mut info2 = LoopInfo::new(LcnfVarId(0), 0, 8, 1, vec![]);
        info2.is_innermost = false;
        let score_outer = info2.priority_score();
        assert!(score_inner > score_outer);
    }
    #[test]
    fn test_default_config_values() {
        let cfg = UnrollConfig::default();
        assert_eq!(cfg.max_unroll_factor, 8);
        assert_eq!(cfg.max_unrolled_size, 256);
        assert_eq!(cfg.unroll_full_threshold, 16);
        assert!(cfg.enable_vectorizable);
    }
    #[test]
    fn test_aggressive_config_larger_limits() {
        let agg = UnrollConfig::aggressive();
        let def = UnrollConfig::default();
        assert!(agg.max_unroll_factor >= def.max_unroll_factor);
        assert!(agg.max_unrolled_size >= def.max_unrolled_size);
    }
    #[test]
    fn test_conservative_config_smaller_limits() {
        let con = UnrollConfig::conservative();
        let def = UnrollConfig::default();
        assert!(con.max_unroll_factor <= def.max_unroll_factor);
        assert!(!con.enable_vectorizable);
    }
    #[test]
    fn test_compute_factor_full_for_small_trip() {
        let pass = LoopUnrollPass::default_pass();
        let info = LoopInfo::new(LcnfVarId(0), 0, 4, 1, vec![]);
        assert_eq!(pass.compute_unroll_factor(&info), UnrollFactor::Full);
    }
    #[test]
    fn test_compute_factor_partial_for_medium_trip() {
        let pass = LoopUnrollPass::default_pass();
        let mut info = LoopInfo::new(LcnfVarId(0), 0, 32, 1, vec![]);
        info.estimated_size = 10;
        let factor = pass.compute_unroll_factor(&info);
        assert_ne!(factor, UnrollFactor::Full);
    }
    #[test]
    fn test_compute_factor_vectorizable_for_divisible_trip() {
        let mut cfg = UnrollConfig::default();
        cfg.enable_vectorizable = true;
        let pass = LoopUnrollPass::new(cfg);
        let mut info = LoopInfo::new(LcnfVarId(0), 0, 32, 1, vec![]);
        info.estimated_size = 5;
        info.is_innermost = true;
        let factor = pass.compute_unroll_factor(&info);
        assert!(matches!(factor, UnrollFactor::Vectorizable(_)));
    }
    #[test]
    fn test_compute_factor_unknown_trip_gives_partial2() {
        let pass = LoopUnrollPass::default_pass();
        let info = LoopInfo {
            loop_var: LcnfVarId(0),
            start: 0,
            end: 0,
            step: 0,
            body: vec![],
            trip_count: None,
            is_innermost: true,
            is_counted: false,
            estimated_size: 10,
        };
        assert_eq!(pass.compute_unroll_factor(&info), UnrollFactor::Partial(2));
    }
    #[test]
    fn test_unroll_loop_partial_2_doubles_body() {
        let mut pass = LoopUnrollPass::default_pass();
        let body = vec![make_return_nat(0), make_return_nat(1)];
        let result = pass.unroll_loop(&body, &UnrollFactor::Partial(2));
        assert_eq!(result.len(), body.len() * 2);
    }
    #[test]
    fn test_unroll_loop_partial_4_quadruples_body() {
        let mut pass = LoopUnrollPass::default_pass();
        let body = vec![make_return_nat(42)];
        let result = pass.unroll_loop(&body, &UnrollFactor::Partial(4));
        assert_eq!(result.len(), 4);
    }
    #[test]
    fn test_unroll_loop_jamming_returns_unchanged() {
        let mut pass = LoopUnrollPass::default_pass();
        let body = vec![make_return_nat(7)];
        let result = pass.unroll_loop(&body, &UnrollFactor::Jamming);
        assert_eq!(result.len(), body.len());
    }
    #[test]
    fn test_unroll_loop_vectorizable_replicates() {
        let mut pass = LoopUnrollPass::default_pass();
        let body = vec![make_return_nat(0)];
        let result = pass.unroll_loop(&body, &UnrollFactor::Vectorizable(4));
        assert_eq!(result.len(), 4);
    }
    #[test]
    fn test_run_empty_decls() {
        let mut pass = LoopUnrollPass::default_pass();
        let mut decls: Vec<LcnfFunDecl> = vec![];
        pass.run(&mut decls);
        assert_eq!(pass.report().loops_analyzed, 0);
    }
    #[test]
    fn test_run_simple_decl_no_loops() {
        let mut pass = LoopUnrollPass::default_pass();
        let decl = make_decl("foo", make_return_nat(0));
        let mut decls = vec![decl];
        pass.run(&mut decls);
        assert_eq!(pass.report().loops_analyzed, 0);
    }
    #[test]
    fn test_run_preserves_decl_count() {
        let mut pass = LoopUnrollPass::default_pass();
        let d1 = make_decl("f1", make_return_nat(1));
        let d2 = make_decl("f2", make_return_nat(2));
        let mut decls = vec![d1, d2];
        pass.run(&mut decls);
        assert_eq!(decls.len(), 2);
    }
    #[test]
    fn test_report_merge() {
        let mut r1 = UnrollReport {
            loops_analyzed: 3,
            loops_unrolled: 2,
            full_unrolls: 1,
            partial_unrolls: 1,
            jammed_loops: 0,
            vectorizable_loops: 0,
            estimated_speedup: 1.5,
        };
        let r2 = UnrollReport {
            loops_analyzed: 7,
            loops_unrolled: 4,
            full_unrolls: 2,
            partial_unrolls: 2,
            jammed_loops: 2,
            vectorizable_loops: 0,
            estimated_speedup: 2.0,
        };
        r1.merge(&r2);
        assert_eq!(r1.loops_analyzed, 10);
        assert_eq!(r1.loops_unrolled, 6);
        assert_eq!(r1.jammed_loops, 2);
    }
    #[test]
    fn test_report_summary_contains_key_fields() {
        let r = UnrollReport {
            loops_analyzed: 5,
            loops_unrolled: 3,
            full_unrolls: 1,
            partial_unrolls: 2,
            jammed_loops: 0,
            vectorizable_loops: 0,
            estimated_speedup: 1.8,
        };
        let s = r.summary();
        assert!(s.contains("analyzed=5"));
        assert!(s.contains("unrolled=3"));
    }
    #[test]
    fn test_estimate_size_return_is_1() {
        assert_eq!(estimate_expr_size(&make_return_nat(0)), 1);
    }
    #[test]
    fn test_estimate_size_let_adds_1() {
        let e = make_nat_lit(0, 42, make_return_nat(0));
        assert_eq!(estimate_expr_size(&e), 2);
    }
    #[test]
    fn test_estimate_size_chain() {
        let e = make_nat_lit(0, 1, make_nat_lit(1, 2, make_return_nat(0)));
        assert_eq!(estimate_expr_size(&e), 3);
    }
    #[test]
    fn test_estimate_size_case_includes_default() {
        // 1 (case) + 1 (single alternative) + 1 (default branch)
        let e = LcnfExpr::Case {
            scrutinee: LcnfVarId(0),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![LcnfAlt {
                ctor_name: "zero".to_string(),
                ctor_tag: 0,
                params: vec![],
                body: make_return_nat(0),
            }],
            default: Some(Box::new(make_return_nat(1))),
        };
        assert_eq!(estimate_expr_size(&e), 3);
    }
    #[test]
    fn test_count_var_refs_return() {
        let e = LcnfExpr::Return(LcnfArg::Var(LcnfVarId(5)));
        assert_eq!(count_var_refs(&e, LcnfVarId(5)), 1);
        assert_eq!(count_var_refs(&e, LcnfVarId(6)), 0);
    }
    #[test]
    fn test_count_var_refs_in_let_value() {
        let e = LcnfExpr::Let {
            id: LcnfVarId(1),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(5)),
            body: Box::new(make_return_nat(0)),
        };
        assert_eq!(count_var_refs(&e, LcnfVarId(5)), 1);
    }
    #[test]
    fn test_count_var_refs_shadowed() {
        let e = LcnfExpr::Let {
            id: LcnfVarId(5),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Erased,
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(5)))),
        };
        assert_eq!(count_var_refs(&e, LcnfVarId(5)), 0);
    }
    #[test]
    fn test_count_var_refs_case_counts_scrutinee_alts_and_default() {
        // scrutinee (1) + alternative body (1) + default body (1)
        let e = LcnfExpr::Case {
            scrutinee: LcnfVarId(0),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![LcnfAlt {
                ctor_name: "zero".to_string(),
                ctor_tag: 0,
                params: vec![],
                body: make_return_var(0),
            }],
            default: Some(Box::new(make_return_var(0))),
        };
        assert_eq!(count_var_refs(&e, LcnfVarId(0)), 3);
    }
    #[test]
    fn test_count_var_refs_case_alt_param_shadows_target() {
        let e = LcnfExpr::Case {
            scrutinee: LcnfVarId(1),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![LcnfAlt {
                ctor_name: "succ".to_string(),
                ctor_tag: 1,
                params: vec![make_nat_param(5, "n")],
                body: make_return_var(5),
            }],
            default: None,
        };
        assert_eq!(count_var_refs(&e, LcnfVarId(5)), 0);
    }
    #[test]
    fn test_candidate_is_profitable_positive_savings() {
        let info = LoopInfo::new(LcnfVarId(0), 0, 4, 1, vec![]);
        let c = UnrollCandidate::new("f", info, UnrollFactor::Full, 10);
        assert!(c.is_profitable());
    }
    #[test]
    fn test_candidate_is_not_profitable_negative_savings() {
        let info = LoopInfo::new(LcnfVarId(0), 0, 4, 1, vec![]);
        let c = UnrollCandidate::new("f", info, UnrollFactor::Full, -5);
        assert!(!c.is_profitable());
    }
    #[test]
    fn test_case_expr_size() {
        let case_expr = LcnfExpr::Case {
            scrutinee: LcnfVarId(0),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![
                LcnfAlt {
                    ctor_name: "zero".to_string(),
                    ctor_tag: 0,
                    params: vec![],
                    body: make_return_nat(0),
                },
                LcnfAlt {
                    ctor_name: "succ".to_string(),
                    ctor_tag: 1,
                    params: vec![make_nat_param(1, "n")],
                    body: make_return_nat(1),
                },
            ],
            default: None,
        };
        assert_eq!(estimate_expr_size(&case_expr), 3);
    }
}
430/// Trait for a single loop optimization pass in the pipeline.
431#[allow(dead_code)]
432pub trait LoopOptPass {
433    /// Name of this pass.
434    fn name(&self) -> &str;
435    /// Run this pass on the given function declarations.
436    fn run_pass(&mut self, decls: &mut [LcnfFunDecl]) -> UnrollReport;
437}
/// Unit tests for the generic pass infrastructure types re-exported from
/// `super::types` (config, stats, registry, cache, worklist, dominators,
/// liveness, constant folding, dependency graph).
///
/// Named in snake_case (formerly `LU_infra_tests`) so it no longer trips
/// rustc's default `non_snake_case` lint.
#[cfg(test)]
mod lu_infra_tests {
    use super::*;

    #[test]
    fn test_pass_config() {
        let config = LUPassConfig::new("test_pass", LUPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }
    #[test]
    fn test_pass_stats() {
        let mut stats = LUPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }
    #[test]
    fn test_pass_registry() {
        let mut reg = LUPassRegistry::new();
        reg.register(LUPassConfig::new("pass_a", LUPassPhase::Analysis));
        reg.register(LUPassConfig::new("pass_b", LUPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }
    #[test]
    fn test_analysis_cache() {
        let mut cache = LUAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        // One hit, one miss.
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        // Invalidation marks the entry stale but keeps it in the cache.
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }
    #[test]
    fn test_worklist() {
        let mut wl = LUWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        // Duplicates are rejected while still queued.
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }
    #[test]
    fn test_dominator_tree() {
        let mut dt = LUDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        // Dominance is reflexive.
        assert!(dt.dominates(3, 3));
    }
    #[test]
    fn test_liveness() {
        let mut liveness = LULivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }
    #[test]
    fn test_constant_folding() {
        assert_eq!(LUConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(LUConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(LUConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            LUConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(LUConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }
    #[test]
    fn test_dep_graph() {
        let mut g = LUDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Position of each node in the topological order.
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}