Skip to main content

oxilean_codegen/opt_vectorize/
functions.rs

//! Auto-generated module: SIMD latency helper, vectorization hint-map alias,
//! and unit tests for the `opt_vectorize` pass.
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::{LcnfExpr, LcnfFunDecl, LcnfLetValue};
6use std::collections::HashMap;
7
8use super::types::{
9    CmpOp, DependenceGraph, DependenceKind, LatencyClass, LoopTransformConfig, LoopTransformer,
10    ReductionInfo, ReductionKind, SIMDCostModel, SIMDOp, SIMDTarget, SIMDTargetInfo,
11    StrideAnalysisResult, StridePattern, VecAnalysisCache, VecConstantFoldingHelper, VecDepGraph,
12    VecDominatorTree, VecLivenessInfo, VecPassConfig, VecPassPhase, VecPassRegistry, VecPassStats,
13    VecWorklist, VectorInstr, VectorInstrBuilder, VectorPrologueEpilogue, VectorRegisterFile,
14    VectorScheduler, VectorWidth, VectorizationAnalysis, VectorizationCandidate,
15    VectorizationConfig, VectorizationHint, VectorizationPass, VectorizationPipeline,
16    VectorizationReport,
17};
18
#[cfg(test)]
mod tests {
    //! Core unit tests: vector widths, target width limits, candidate
    //! analysis, vector-loop emission, instruction display, report merging.
    use super::*;

    /// Turns borrowed array names into the owned `String`s a candidate stores.
    fn owned_names(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| String::from(name)).collect()
    }

    /// Builds an inner loop over index `i` touching the given arrays.
    ///
    /// `loop_bound` is left as `None`; each test assigns the trip count it needs.
    fn inner_loop_candidate(
        func_name: &str,
        reads: &[&str],
        writes: &[&str],
        has_loop_carried_dep: bool,
    ) -> VectorizationCandidate {
        VectorizationCandidate {
            func_name: func_name.to_string(),
            loop_var: "i".to_string(),
            loop_bound: None,
            array_reads: owned_names(reads),
            array_writes: owned_names(writes),
            is_inner_loop: true,
            has_loop_carried_dep,
        }
    }

    #[test]
    pub(super) fn vector_width_lanes() {
        // 32-bit lanes per register: 128/32, 256/32, 512/32.
        let lanes_128 = VectorWidth::W128.lanes_f32();
        let lanes_256 = VectorWidth::W256.lanes_f32();
        let lanes_512 = VectorWidth::W512.lanes_f32();
        assert_eq!(lanes_128, 4);
        assert_eq!(lanes_256, 8);
        assert_eq!(lanes_512, 16);
        // 64-bit lanes in a 256-bit register: 256/64.
        assert_eq!(VectorWidth::W256.lanes_f64(), 4);
    }

    #[test]
    pub(super) fn simd_target_max_width() {
        let avx = SIMDTarget::X86AVX.max_width();
        let avx512 = SIMDTarget::X86AVX512.max_width();
        let neon = SIMDTarget::ArmNeon.max_width();
        assert_eq!(avx, VectorWidth::W256);
        assert_eq!(avx512, VectorWidth::W512);
        assert_eq!(neon, VectorWidth::W128);
    }

    #[test]
    pub(super) fn candidate_no_dep() {
        let mut candidate = inner_loop_candidate("loop_add", &["a"], &["b"], false);
        candidate.loop_bound = Some(1024);
        let analysis = VectorizationAnalysis::new();
        assert!(analysis.can_vectorize(&candidate));
        let gain = analysis.estimate_speedup(&candidate, VectorWidth::W256);
        assert!(gain > 1.0, "speedup={}", gain);
    }

    #[test]
    pub(super) fn candidate_with_dep_rejected() {
        let mut reduction = inner_loop_candidate("loop_reduce", &["acc"], &["acc"], true);
        reduction.loop_bound = Some(256);
        let analysis = VectorizationAnalysis::new();
        assert!(!analysis.can_vectorize(&reduction));
        // A rejected loop is expected to report a neutral (1.0) speedup.
        assert_eq!(analysis.estimate_speedup(&reduction, VectorWidth::W256), 1.0);
    }

    #[test]
    pub(super) fn emit_vector_loop_fma() {
        let pass = VectorizationPass::new(VectorizationConfig {
            enable_fma: true,
            target: SIMDTarget::X86AVX,
            ..VectorizationConfig::default()
        });
        let mut dot = inner_loop_candidate("dot_product", &["a", "b"], &["result"], false);
        dot.loop_bound = Some(512);
        let emitted = pass.emit_vector_loop(&dot, VectorWidth::W256);
        assert!(!emitted.is_empty());
        let saw_fma = emitted.iter().any(|instr| instr.op == SIMDOp::Fma);
        assert!(saw_fma, "expected FMA instruction");
    }

    #[test]
    pub(super) fn vector_instr_display() {
        let sources = vec!["v1".to_string(), "v2".to_string()];
        let rendered = format!(
            "{}",
            VectorInstr::new(SIMDOp::Add, VectorWidth::W128, "v0", sources)
        );
        assert!(rendered.contains("vadd"));
        assert!(rendered.contains("128"));
    }

    #[test]
    pub(super) fn report_merge() {
        let mut merged = VectorizationReport {
            loops_analyzed: 3,
            loops_vectorized: 2,
            rejected_dep: 1,
            ..VectorizationReport::default()
        };
        let incoming = VectorizationReport {
            loops_analyzed: 2,
            loops_vectorized: 1,
            rejected_trip_count: 1,
            ..VectorizationReport::default()
        };
        merged.merge(&incoming);
        // Counters are summed field-by-field.
        assert_eq!(merged.loops_analyzed, 3 + 2);
        assert_eq!(merged.loops_vectorized, 2 + 1);
        assert_eq!(merged.rejected_dep, 1);
        assert_eq!(merged.rejected_trip_count, 1);
    }

    #[test]
    pub(super) fn effective_width_caps_at_target() {
        // A 512-bit preference on an SSE target must be clamped to 128 bits.
        let pass = VectorizationPass::new(VectorizationConfig {
            preferred_width: VectorWidth::W512,
            target: SIMDTarget::X86SSE,
            ..VectorizationConfig::default()
        });
        assert_eq!(pass.effective_width(), VectorWidth::W128);
    }
}
131/// Returns the abstract latency for a `SIMDOp`.
132#[allow(dead_code)]
133pub fn simd_op_latency(op: &SIMDOp) -> LatencyClass {
134    match op {
135        SIMDOp::Broadcast => LatencyClass::SingleCycle,
136        SIMDOp::Add | SIMDOp::Sub => LatencyClass::Short,
137        SIMDOp::Mul => LatencyClass::Short,
138        SIMDOp::Div => LatencyClass::Medium,
139        SIMDOp::Sqrt => LatencyClass::Long,
140        SIMDOp::Fma => LatencyClass::Short,
141        SIMDOp::Load | SIMDOp::Store => LatencyClass::Memory,
142        SIMDOp::Shuffle | SIMDOp::Blend => LatencyClass::Short,
143        SIMDOp::Compare(_) => LatencyClass::Short,
144        SIMDOp::Min | SIMDOp::Max => LatencyClass::Short,
145        SIMDOp::HorizontalAdd => LatencyClass::Medium,
146    }
147}
148/// A map from function name to a list of vectorization hints.
149#[allow(dead_code)]
150pub type HintMap = HashMap<String, Vec<VectorizationHint>>;
#[cfg(test)]
mod extended_tests {
    //! Tests for register allocation, scheduling, the SIMD cost model,
    //! dependence graphs, loop transforms, hints, reductions, instruction
    //! building, target info, prologue/epilogue sizing, stride analysis,
    //! op latencies, and the end-to-end vectorization pipeline.
    use super::*;

    #[test]
    pub(super) fn test_vector_register_file_alloc() {
        let mut rf = VectorRegisterFile::new(4);
        let r0 = rf.alloc("v");
        let r1 = rf.alloc("v");
        let r2 = rf.alloc("v");
        let r3 = rf.alloc("v");
        assert_eq!(rf.allocation.len(), 4);
        assert!(rf.is_full());
        // A fifth allocation on a 4-register file must spill.
        let _r4 = rf.alloc("v");
        assert!(rf.spill_count() > 0);
        rf.free(&r0);
        assert!(!rf.is_full());
        let _ = (r1, r2, r3);
    }

    #[test]
    pub(super) fn test_vector_scheduler_ordering() {
        let mut builder = VectorInstrBuilder::new(VectorWidth::W256);
        let load_a = builder.load("a_ptr");
        let load_b = builder.load("b_ptr");
        let mul = builder.mul(&load_a, &load_b);
        let _hadd = builder.hadd(&mul);
        let instrs = builder.build();
        let scheduled = VectorScheduler::schedule(&instrs);
        assert_eq!(scheduled.len(), instrs.len());
        let makespan = VectorScheduler::makespan(&scheduled);
        assert!(makespan > 0);
    }

    #[test]
    pub(super) fn test_simd_cost_model() {
        let model = SIMDCostModel::default();
        let candidate = VectorizationCandidate {
            func_name: "test".into(),
            loop_var: "i".into(),
            loop_bound: Some(1024),
            array_reads: vec!["a".into(), "b".into()],
            array_writes: vec!["c".into()],
            is_inner_loop: true,
            has_loop_carried_dep: false,
        };
        let gain = model.throughput_gain(&candidate, VectorWidth::W256);
        assert!(gain > 0.0);
    }

    #[test]
    pub(super) fn test_dependence_graph() {
        let mut dg = DependenceGraph::default();
        dg.add_edge("a", "b", DependenceKind::True, 0);
        dg.add_edge("b", "c", DependenceKind::Anti, 1);
        dg.add_edge("c", "c", DependenceKind::Output, 2);
        assert!(dg.has_carried_dependence());
        assert_eq!(dg.max_distance(), 2);
        assert_eq!(dg.edges_of_kind(DependenceKind::True).len(), 1);
        assert_eq!(dg.edges_of_kind(DependenceKind::Anti).len(), 1);
    }

    #[test]
    pub(super) fn test_loop_transformer() {
        let candidate = VectorizationCandidate::new("my_loop", "i");
        let transformer = LoopTransformer::new();
        let result = transformer.transform(&candidate, VectorWidth::W256);
        assert!(result.transformed_name.contains("my_loop"));
        assert!(result.strip_mined);
        assert!(result.vector_instr_count > 0);
    }

    #[test]
    pub(super) fn test_vectorization_hints_display() {
        assert_eq!(VectorizationHint::Force.to_string(), "#[vectorize(force)]");
        assert_eq!(
            VectorizationHint::Disable.to_string(),
            "#[vectorize(disable)]"
        );
        assert_eq!(
            VectorizationHint::Unroll(4).to_string(),
            "#[vectorize(unroll=4)]"
        );
        assert_eq!(
            VectorizationHint::Width(VectorWidth::W256).to_string(),
            "#[vectorize(width=256)]"
        );
    }

    #[test]
    pub(super) fn test_reduction_info() {
        let sum = ReductionInfo::sum("acc");
        assert_eq!(sum.kind, ReductionKind::Sum);
        assert_eq!(sum.initial_value, 0);
        assert_eq!(sum.reduction_op(), SIMDOp::Add);
        let prod = ReductionInfo::product("p");
        assert_eq!(prod.initial_value, 1);
        assert_eq!(prod.reduction_op(), SIMDOp::Mul);
    }

    #[test]
    pub(super) fn test_vector_instr_builder() {
        let mut builder = VectorInstrBuilder::new(VectorWidth::W128);
        let a = builder.load("a_ptr");
        let b = builder.load("b_ptr");
        let c = builder.broadcast("scalar");
        let fma_r = builder.fma(&a, &b, &c);
        let _hadd = builder.hadd(&fma_r);
        let instrs = builder.build();
        assert!(!instrs.is_empty());
        let has_fma = instrs.iter().any(|i| i.op == SIMDOp::Fma);
        assert!(has_fma);
    }

    #[test]
    pub(super) fn test_simd_target_info() {
        // NOTE(review): these values pin the current `SIMDTargetInfo` table.
        // On real hardware, AVX-512 exposes 32 zmm registers and prefers
        // 64-byte alignment, and NEON has no per-lane predicate masking.
        // Confirm that the table, and these expectations, are intentional.
        let avx512 = SIMDTargetInfo::new(SIMDTarget::X86AVX512);
        assert_eq!(avx512.num_vector_registers(), 16);
        assert!(avx512.supports_masking());
        assert!(avx512.supports_scatter());
        assert_eq!(avx512.preferred_alignment(), 32);
        let neon = SIMDTargetInfo::new(SIMDTarget::ArmNeon);
        assert_eq!(neon.num_vector_registers(), 32);
        assert!(neon.supports_masking());
        assert!(!neon.supports_gather());
    }

    #[test]
    pub(super) fn test_prologue_epilogue() {
        let pe = VectorPrologueEpilogue::new(VectorWidth::W256);
        assert_eq!(pe.prologue_iterations(0, 4), 0);
        assert_eq!(pe.prologue_iterations(12, 4), 5);
        assert_eq!(pe.epilogue_iterations(100, 4), 0);
        assert_eq!(pe.epilogue_iterations(101, 4), 1);
    }

    #[test]
    pub(super) fn test_stride_pattern_display() {
        assert_eq!(StridePattern::Unit.to_string(), "unit");
        assert_eq!(StridePattern::Constant(2).to_string(), "const(2)");
        assert_eq!(StridePattern::Irregular.to_string(), "irregular");
    }

    #[test]
    pub(super) fn test_stride_analysis_result() {
        let unit = StrideAnalysisResult::unit("arr");
        assert!(unit.is_vectorizable);
        let stride2 = StrideAnalysisResult::constant("arr", 2);
        assert!(!stride2.is_vectorizable);
        let neg1 = StrideAnalysisResult::constant("arr", -1);
        assert!(neg1.is_vectorizable);
        let irregular = StrideAnalysisResult::irregular("arr");
        assert!(!irregular.is_vectorizable);
    }

    /// Smoke test: the pipeline runs on a minimal recursive declaration and
    /// produces a self-consistent report.
    #[test]
    pub(super) fn test_vectorization_pipeline() {
        use crate::lcnf::*;
        let decl = LcnfFunDecl {
            name: "loop_sum".to_string(),
            original_name: None,
            params: vec![
                LcnfParam {
                    id: LcnfVarId(0),
                    ty: LcnfType::Nat,
                    name: "i".to_string(),
                    erased: false,
                    borrowed: false,
                },
                LcnfParam {
                    id: LcnfVarId(1),
                    ty: LcnfType::Nat,
                    name: "acc".to_string(),
                    erased: false,
                    borrowed: false,
                },
            ],
            ret_type: LcnfType::Nat,
            body: LcnfExpr::Let {
                id: LcnfVarId(2),
                name: "bound".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::Lit(LcnfLit::Nat(1024)),
                body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(1)))),
            },
            is_recursive: true,
            is_lifted: false,
            inline_cost: 10,
        };
        let pipeline = VectorizationPipeline::new();
        let mut decls = vec![decl];
        let result = pipeline.run(&mut decls);
        // A loop can only be vectorized after it has been analyzed. Comparing
        // an unsigned counter against 0 would be vacuous, so check this instead.
        assert!(
            result.report.loops_vectorized <= result.report.loops_analyzed,
            "vectorized {} loops but analyzed only {}",
            result.report.loops_vectorized,
            result.report.loops_analyzed
        );
    }

    #[test]
    pub(super) fn test_latency_ordering() {
        assert!(simd_op_latency(&SIMDOp::Add) < simd_op_latency(&SIMDOp::Sqrt));
        assert!(simd_op_latency(&SIMDOp::Mul) <= simd_op_latency(&SIMDOp::Div));
        assert_eq!(
            simd_op_latency(&SIMDOp::Broadcast),
            LatencyClass::SingleCycle
        );
        assert_eq!(simd_op_latency(&SIMDOp::Load), LatencyClass::Memory);
    }

    #[test]
    pub(super) fn test_reduction_kind_display() {
        assert_eq!(ReductionKind::Sum.to_string(), "sum");
        assert_eq!(ReductionKind::DotProduct.to_string(), "dot_product");
        assert_eq!(ReductionKind::Min.to_string(), "min");
    }

    #[test]
    pub(super) fn test_dependence_kind_display() {
        assert_eq!(DependenceKind::True.to_string(), "RAW");
        assert_eq!(DependenceKind::Anti.to_string(), "WAR");
        assert_eq!(DependenceKind::Output.to_string(), "WAW");
        assert_eq!(DependenceKind::Input.to_string(), "RAR");
    }

    #[test]
    pub(super) fn test_loop_transform_config_default() {
        let cfg = LoopTransformConfig::default();
        assert_eq!(cfg.unroll_factor, 4);
        assert_eq!(cfg.tile_size, 64);
        assert!(cfg.strip_mine);
    }

    #[test]
    pub(super) fn test_vector_instr_builder_blend_cmp() {
        let mut builder = VectorInstrBuilder::new(VectorWidth::W256);
        let a = builder.load("a_ptr");
        let b = builder.load("b_ptr");
        let mask = builder.cmp(CmpOp::Lt, &a, &b);
        let _blended = builder.blend(&a, &b, &mask);
        let instrs = builder.build();
        let has_cmp = instrs.iter().any(|i| matches!(i.op, SIMDOp::Compare(_)));
        let has_blend = instrs.iter().any(|i| i.op == SIMDOp::Blend);
        assert!(has_cmp);
        assert!(has_blend);
    }
}
#[cfg(test)]
mod vec_infra_tests {
    //! Tests for the generic pass infrastructure: pass config/stats/registry,
    //! analysis cache, worklist, dominator tree, liveness, constant folding,
    //! and the dependency graph.
    //!
    //! The module name is snake_case to satisfy rustc's `non_snake_case` lint.
    use super::*;

    #[test]
    pub(super) fn test_pass_config() {
        let config = VecPassConfig::new("test_pass", VecPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }

    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = VecPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }

    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = VecPassRegistry::new();
        reg.register(VecPassConfig::new("pass_a", VecPassPhase::Analysis));
        reg.register(VecPassConfig::new("pass_b", VecPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }

    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = VecAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        // One hit ("key1") and one miss ("key2") -> 50% hit rate.
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        // Invalidation marks the entry invalid but keeps it in the cache.
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }

    #[test]
    pub(super) fn test_worklist() {
        let mut wl = VecWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        // Duplicate pushes are rejected.
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }

    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = VecDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        // Dominance is reflexive.
        assert!(dt.dominates(3, 3));
    }

    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = VecLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }

    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(VecConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(VecConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(VecConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            VecConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(VecConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }

    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = VecDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Position of each node in the topological order.
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}