1use crate::lcnf::{LcnfExpr, LcnfFunDecl, LcnfLetValue};
6use std::collections::HashMap;
7
8use super::types::{
9 CmpOp, DependenceGraph, DependenceKind, LatencyClass, LoopTransformConfig, LoopTransformer,
10 ReductionInfo, ReductionKind, SIMDCostModel, SIMDOp, SIMDTarget, SIMDTargetInfo,
11 StrideAnalysisResult, StridePattern, VecAnalysisCache, VecConstantFoldingHelper, VecDepGraph,
12 VecDominatorTree, VecLivenessInfo, VecPassConfig, VecPassPhase, VecPassRegistry, VecPassStats,
13 VecWorklist, VectorInstr, VectorInstrBuilder, VectorPrologueEpilogue, VectorRegisterFile,
14 VectorScheduler, VectorWidth, VectorizationAnalysis, VectorizationCandidate,
15 VectorizationConfig, VectorizationHint, VectorizationPass, VectorizationPipeline,
16 VectorizationReport,
17};
18
#[cfg(test)]
mod tests {
    use super::*;

    /// Lane counts reported for f32 / f64 elements at each vector width.
    #[test]
    pub(super) fn vector_width_lanes() {
        // Single-precision lanes at 128 / 256 / 512 bits.
        assert_eq!(VectorWidth::W128.lanes_f32(), 4);
        assert_eq!(VectorWidth::W256.lanes_f32(), 8);
        assert_eq!(VectorWidth::W512.lanes_f32(), 16);
        // Double-precision lanes at 256 bits.
        assert_eq!(VectorWidth::W256.lanes_f64(), 4);
    }

    /// Each SIMD target caps out at its widest supported register.
    #[test]
    pub(super) fn simd_target_max_width() {
        assert_eq!(SIMDTarget::X86AVX.max_width(), VectorWidth::W256);
        assert_eq!(SIMDTarget::X86AVX512.max_width(), VectorWidth::W512);
        assert_eq!(SIMDTarget::ArmNeon.max_width(), VectorWidth::W128);
    }

    /// An inner loop without a loop-carried dependence is accepted and is
    /// predicted to run faster than scalar code.
    #[test]
    pub(super) fn candidate_no_dep() {
        let candidate = VectorizationCandidate {
            has_loop_carried_dep: false,
            is_inner_loop: true,
            array_writes: vec![String::from("b")],
            array_reads: vec![String::from("a")],
            loop_bound: Some(1024),
            loop_var: String::from("i"),
            func_name: String::from("loop_add"),
        };
        let analysis = VectorizationAnalysis::new();
        assert!(analysis.can_vectorize(&candidate));
        let gain = analysis.estimate_speedup(&candidate, VectorWidth::W256);
        assert!(gain > 1.0, "speedup={}", gain);
    }

    /// A loop-carried dependence blocks vectorization, and the speedup
    /// estimate falls back to exactly 1.0 (no gain).
    #[test]
    pub(super) fn candidate_with_dep_rejected() {
        let candidate = VectorizationCandidate {
            has_loop_carried_dep: true,
            is_inner_loop: true,
            array_writes: vec![String::from("acc")],
            array_reads: vec![String::from("acc")],
            loop_bound: Some(256),
            loop_var: String::from("i"),
            func_name: String::from("loop_reduce"),
        };
        let analysis = VectorizationAnalysis::new();
        assert!(!analysis.can_vectorize(&candidate));
        assert_eq!(analysis.estimate_speedup(&candidate, VectorWidth::W256), 1.0);
    }

    /// With FMA enabled on AVX, the emitted loop contains at least one FMA.
    #[test]
    pub(super) fn emit_vector_loop_fma() {
        let pass = VectorizationPass::new(VectorizationConfig {
            target: SIMDTarget::X86AVX,
            enable_fma: true,
            ..VectorizationConfig::default()
        });
        let candidate = VectorizationCandidate {
            has_loop_carried_dep: false,
            is_inner_loop: true,
            array_writes: vec![String::from("result")],
            array_reads: vec![String::from("a"), String::from("b")],
            loop_bound: Some(512),
            loop_var: String::from("i"),
            func_name: String::from("dot_product"),
        };
        let emitted = pass.emit_vector_loop(&candidate, VectorWidth::W256);
        assert!(!emitted.is_empty());
        let saw_fma = emitted.iter().any(|instr| instr.op == SIMDOp::Fma);
        assert!(saw_fma, "expected FMA instruction");
    }

    /// The textual form of an instruction names the op and the width.
    #[test]
    pub(super) fn vector_instr_display() {
        let operands = vec![String::from("v1"), String::from("v2")];
        let instr = VectorInstr::new(SIMDOp::Add, VectorWidth::W128, "v0", operands);
        let rendered = format!("{}", instr);
        assert!(rendered.contains("vadd"));
        assert!(rendered.contains("128"));
    }

    /// Merging two reports combines their counters.
    #[test]
    pub(super) fn report_merge() {
        let other = VectorizationReport {
            loops_analyzed: 2,
            loops_vectorized: 1,
            rejected_trip_count: 1,
            ..VectorizationReport::default()
        };
        let mut merged = VectorizationReport {
            loops_analyzed: 3,
            loops_vectorized: 2,
            rejected_dep: 1,
            ..VectorizationReport::default()
        };
        merged.merge(&other);
        assert_eq!(merged.loops_analyzed, 5);
        assert_eq!(merged.loops_vectorized, 3);
        assert_eq!(merged.rejected_dep, 1);
        assert_eq!(merged.rejected_trip_count, 1);
    }

    /// A preferred width wider than the target supports is clamped down to
    /// the target's limit (512 requested on SSE yields 128).
    #[test]
    pub(super) fn effective_width_caps_at_target() {
        let pass = VectorizationPass::new(VectorizationConfig {
            target: SIMDTarget::X86SSE,
            preferred_width: VectorWidth::W512,
            ..VectorizationConfig::default()
        });
        assert_eq!(pass.effective_width(), VectorWidth::W128);
    }
}
131#[allow(dead_code)]
133pub fn simd_op_latency(op: &SIMDOp) -> LatencyClass {
134 match op {
135 SIMDOp::Broadcast => LatencyClass::SingleCycle,
136 SIMDOp::Add | SIMDOp::Sub => LatencyClass::Short,
137 SIMDOp::Mul => LatencyClass::Short,
138 SIMDOp::Div => LatencyClass::Medium,
139 SIMDOp::Sqrt => LatencyClass::Long,
140 SIMDOp::Fma => LatencyClass::Short,
141 SIMDOp::Load | SIMDOp::Store => LatencyClass::Memory,
142 SIMDOp::Shuffle | SIMDOp::Blend => LatencyClass::Short,
143 SIMDOp::Compare(_) => LatencyClass::Short,
144 SIMDOp::Min | SIMDOp::Max => LatencyClass::Short,
145 SIMDOp::HorizontalAdd => LatencyClass::Medium,
146 }
147}
148#[allow(dead_code)]
150pub type HintMap = HashMap<String, Vec<VectorizationHint>>;
#[cfg(test)]
mod extended_tests {
    use super::*;

    /// Allocating past capacity spills; freeing a register makes room again.
    #[test]
    pub(super) fn test_vector_register_file_alloc() {
        let mut rf = VectorRegisterFile::new(4);
        let r0 = rf.alloc("v");
        // Held (not dropped) so all four slots stay allocated.
        let _r1 = rf.alloc("v");
        let _r2 = rf.alloc("v");
        let _r3 = rf.alloc("v");
        assert_eq!(rf.allocation.len(), 4);
        assert!(rf.is_full());
        // A fifth allocation on a full file must be recorded as a spill.
        let _r4 = rf.alloc("v");
        assert!(rf.spill_count() > 0);
        rf.free(&r0);
        assert!(!rf.is_full());
    }

    /// Scheduling keeps every instruction and yields a positive makespan.
    #[test]
    pub(super) fn test_vector_scheduler_ordering() {
        let mut builder = VectorInstrBuilder::new(VectorWidth::W256);
        let load_a = builder.load("a_ptr");
        let load_b = builder.load("b_ptr");
        let mul = builder.mul(&load_a, &load_b);
        let _hadd = builder.hadd(&mul);
        let instrs = builder.build();
        let scheduled = VectorScheduler::schedule(&instrs);
        assert_eq!(scheduled.len(), instrs.len());
        let makespan = VectorScheduler::makespan(&scheduled);
        assert!(makespan > 0);
    }

    /// A dependence-free inner loop has positive estimated throughput gain.
    #[test]
    pub(super) fn test_simd_cost_model() {
        let model = SIMDCostModel::default();
        let candidate = VectorizationCandidate {
            func_name: "test".into(),
            loop_var: "i".into(),
            loop_bound: Some(1024),
            array_reads: vec!["a".into(), "b".into()],
            array_writes: vec!["c".into()],
            is_inner_loop: true,
            has_loop_carried_dep: false,
        };
        let gain = model.throughput_gain(&candidate, VectorWidth::W256);
        assert!(gain > 0.0);
    }

    /// Edge bookkeeping: carried-dependence detection, max distance, and
    /// filtering by dependence kind.
    #[test]
    pub(super) fn test_dependence_graph() {
        let mut dg = DependenceGraph::default();
        dg.add_edge("a", "b", DependenceKind::True, 0);
        dg.add_edge("b", "c", DependenceKind::Anti, 1);
        dg.add_edge("c", "c", DependenceKind::Output, 2);
        assert!(dg.has_carried_dependence());
        assert_eq!(dg.max_distance(), 2);
        assert_eq!(dg.edges_of_kind(DependenceKind::True).len(), 1);
        assert_eq!(dg.edges_of_kind(DependenceKind::Anti).len(), 1);
    }

    /// The default transformer strip-mines and emits vector instructions.
    #[test]
    pub(super) fn test_loop_transformer() {
        let candidate = VectorizationCandidate::new("my_loop", "i");
        let transformer = LoopTransformer::new();
        let result = transformer.transform(&candidate, VectorWidth::W256);
        assert!(result.transformed_name.contains("my_loop"));
        assert!(result.strip_mined);
        assert!(result.vector_instr_count > 0);
    }

    /// Hints render as `#[vectorize(...)]` attributes.
    #[test]
    pub(super) fn test_vectorization_hints_display() {
        assert_eq!(VectorizationHint::Force.to_string(), "#[vectorize(force)]");
        assert_eq!(
            VectorizationHint::Disable.to_string(),
            "#[vectorize(disable)]"
        );
        assert_eq!(
            VectorizationHint::Unroll(4).to_string(),
            "#[vectorize(unroll=4)]"
        );
        assert_eq!(
            VectorizationHint::Width(VectorWidth::W256).to_string(),
            "#[vectorize(width=256)]"
        );
    }

    /// Sum/product reductions start from their identity and use add/mul.
    #[test]
    pub(super) fn test_reduction_info() {
        let sum = ReductionInfo::sum("acc");
        assert_eq!(sum.kind, ReductionKind::Sum);
        assert_eq!(sum.initial_value, 0);
        assert_eq!(sum.reduction_op(), SIMDOp::Add);
        let prod = ReductionInfo::product("p");
        assert_eq!(prod.initial_value, 1);
        assert_eq!(prod.reduction_op(), SIMDOp::Mul);
    }

    /// The builder emits an FMA when asked for one.
    #[test]
    pub(super) fn test_vector_instr_builder() {
        let mut builder = VectorInstrBuilder::new(VectorWidth::W128);
        let a = builder.load("a_ptr");
        let b = builder.load("b_ptr");
        let c = builder.broadcast("scalar");
        let fma_r = builder.fma(&a, &b, &c);
        let _hadd = builder.hadd(&fma_r);
        let instrs = builder.build();
        assert!(!instrs.is_empty());
        let has_fma = instrs.iter().any(|i| i.op == SIMDOp::Fma);
        assert!(has_fma);
    }

    /// Pins the per-target capability table.
    #[test]
    pub(super) fn test_simd_target_info() {
        let avx512 = SIMDTargetInfo::new(SIMDTarget::X86AVX512);
        // NOTE(review): architecturally, AVX-512 in 64-bit mode exposes 32
        // zmm registers and 64-byte vectors. These assertions pin the current
        // table values (16 registers, 32-byte alignment) — confirm that is
        // intentional.
        assert_eq!(avx512.num_vector_registers(), 16);
        assert!(avx512.supports_masking());
        assert!(avx512.supports_scatter());
        assert_eq!(avx512.preferred_alignment(), 32);
        let neon = SIMDTargetInfo::new(SIMDTarget::ArmNeon);
        assert_eq!(neon.num_vector_registers(), 32);
        assert!(neon.supports_masking());
        assert!(!neon.supports_gather());
    }

    /// Prologue/epilogue iteration counts at the boundary cases.
    #[test]
    pub(super) fn test_prologue_epilogue() {
        let pe = VectorPrologueEpilogue::new(VectorWidth::W256);
        assert_eq!(pe.prologue_iterations(0, 4), 0);
        assert_eq!(pe.prologue_iterations(12, 4), 5);
        assert_eq!(pe.epilogue_iterations(100, 4), 0);
        assert_eq!(pe.epilogue_iterations(101, 4), 1);
    }

    /// Stride patterns have short textual names.
    #[test]
    pub(super) fn test_stride_pattern_display() {
        assert_eq!(StridePattern::Unit.to_string(), "unit");
        assert_eq!(StridePattern::Constant(2).to_string(), "const(2)");
        assert_eq!(StridePattern::Irregular.to_string(), "irregular");
    }

    /// Unit and reverse-unit strides vectorize; stride 2 and irregular do not.
    #[test]
    pub(super) fn test_stride_analysis_result() {
        let unit = StrideAnalysisResult::unit("arr");
        assert!(unit.is_vectorizable);
        let stride2 = StrideAnalysisResult::constant("arr", 2);
        assert!(!stride2.is_vectorizable);
        let neg1 = StrideAnalysisResult::constant("arr", -1);
        assert!(neg1.is_vectorizable);
        let irregular = StrideAnalysisResult::irregular("arr");
        assert!(!irregular.is_vectorizable);
    }

    /// Smoke test: the pipeline runs over a minimal LCNF declaration and
    /// produces an internally consistent report.
    #[test]
    pub(super) fn test_vectorization_pipeline() {
        use crate::lcnf::*;
        let decl = LcnfFunDecl {
            name: "loop_sum".to_string(),
            original_name: None,
            params: vec![
                LcnfParam {
                    id: LcnfVarId(0),
                    ty: LcnfType::Nat,
                    name: "i".to_string(),
                    erased: false,
                    borrowed: false,
                },
                LcnfParam {
                    id: LcnfVarId(1),
                    ty: LcnfType::Nat,
                    name: "acc".to_string(),
                    erased: false,
                    borrowed: false,
                },
            ],
            ret_type: LcnfType::Nat,
            body: LcnfExpr::Let {
                id: LcnfVarId(2),
                name: "bound".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::Lit(LcnfLit::Nat(1024)),
                body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(1)))),
            },
            is_recursive: true,
            is_lifted: false,
            inline_cost: 10,
        };
        let pipeline = VectorizationPipeline::new();
        let mut decls = vec![decl];
        let result = pipeline.run(&mut decls);
        // The original `loops_analyzed >= 0` was a tautology for an unsigned
        // counter. Check a real invariant instead: a loop cannot be
        // vectorized without first being analyzed.
        let report = &result.report;
        assert!(
            report.loops_vectorized <= report.loops_analyzed,
            "vectorized {} loops but only analyzed {}",
            report.loops_vectorized,
            report.loops_analyzed
        );
    }

    /// Latency classes are ordered, and the extremes map as expected.
    #[test]
    pub(super) fn test_latency_ordering() {
        assert!(simd_op_latency(&SIMDOp::Add) < simd_op_latency(&SIMDOp::Sqrt));
        assert!(simd_op_latency(&SIMDOp::Mul) <= simd_op_latency(&SIMDOp::Div));
        assert_eq!(
            simd_op_latency(&SIMDOp::Broadcast),
            LatencyClass::SingleCycle
        );
        assert_eq!(simd_op_latency(&SIMDOp::Load), LatencyClass::Memory);
    }

    /// Reduction kinds render in snake_case.
    #[test]
    pub(super) fn test_reduction_kind_display() {
        assert_eq!(ReductionKind::Sum.to_string(), "sum");
        assert_eq!(ReductionKind::DotProduct.to_string(), "dot_product");
        assert_eq!(ReductionKind::Min.to_string(), "min");
    }

    /// Dependence kinds render as their hazard abbreviations.
    #[test]
    pub(super) fn test_dependence_kind_display() {
        assert_eq!(DependenceKind::True.to_string(), "RAW");
        assert_eq!(DependenceKind::Anti.to_string(), "WAR");
        assert_eq!(DependenceKind::Output.to_string(), "WAW");
        assert_eq!(DependenceKind::Input.to_string(), "RAR");
    }

    /// Pins the default loop-transform configuration.
    #[test]
    pub(super) fn test_loop_transform_config_default() {
        let cfg = LoopTransformConfig::default();
        assert_eq!(cfg.unroll_factor, 4);
        assert_eq!(cfg.tile_size, 64);
        assert!(cfg.strip_mine);
    }

    /// A compare feeding a blend emits both instruction kinds.
    #[test]
    pub(super) fn test_vector_instr_builder_blend_cmp() {
        let mut builder = VectorInstrBuilder::new(VectorWidth::W256);
        let a = builder.load("a_ptr");
        let b = builder.load("b_ptr");
        let mask = builder.cmp(CmpOp::Lt, &a, &b);
        let _blended = builder.blend(&a, &b, &mask);
        let instrs = builder.build();
        let has_cmp = instrs.iter().any(|i| matches!(i.op, SIMDOp::Compare(_)));
        let has_blend = instrs.iter().any(|i| i.op == SIMDOp::Blend);
        assert!(has_cmp);
        assert!(has_blend);
    }
}
// Renamed from `Vec_infra_tests`: module names are snake_case, and the old
// name triggered rustc's `non_snake_case` warning.
#[cfg(test)]
mod vec_infra_tests {
    use super::*;

    /// A freshly built pass config is enabled, and its phase reports itself
    /// as modifying.
    #[test]
    pub(super) fn test_pass_config() {
        let config = VecPassConfig::new("test_pass", VecPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }

    /// Stats accumulate across runs and summarize averages and success rate.
    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = VecPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }

    /// Disabled passes count toward the total but not the enabled count, and
    /// stats updates are retrievable by pass name.
    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = VecPassRegistry::new();
        reg.register(VecPassConfig::new("pass_a", VecPassPhase::Analysis));
        reg.register(VecPassConfig::new("pass_b", VecPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }

    /// One hit plus one miss gives a 50% hit rate. Invalidation marks the
    /// entry invalid without removing it.
    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = VecAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }

    /// Duplicate pushes are rejected, and pops come out in insertion order.
    #[test]
    pub(super) fn test_worklist() {
        let mut wl = VecWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }

    /// Dominance follows the idom chain, and every node dominates itself.
    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = VecDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }

    /// Defs and uses are recorded per block index.
    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = VecLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }

    /// Folding helpers compute results, and division by zero yields `None`.
    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(VecConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(VecConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(VecConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            VecConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(VecConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }

    /// The topological order respects every edge of an acyclic graph.
    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = VecDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Position of each node in the produced order.
        let pos: HashMap<u32, usize> = topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}