1use crate::{
19 AccessPattern, BinOp, Body, Loop, LoopAttrs, LoopIR, LoopId, LoopMetadata, LoopType, Op,
20 ScalarType, Stmt, TargetArch, TripCount, Value,
21};
22use rustc_hash::FxHashMap;
23use thiserror::Error;
24
25#[derive(Clone, Debug, Error)]
27pub enum VectorizeError {
28 #[error("loop {loop_id:?} cannot be vectorized: {reason}")]
30 NotVectorizable {
31 loop_id: LoopId,
33 reason: String,
35 },
36
37 #[error("invalid vector width {width} for type {ty:?}")]
39 InvalidWidth {
40 width: u8,
42 ty: ScalarType,
44 },
45}
46
47#[derive(Clone, Debug)]
49pub struct VectorizationInfo {
50 pub vectorizable: bool,
52 pub reason: Option<String>,
54 pub recommended_width: u8,
56 pub access_patterns: Vec<AccessPattern>,
58 pub has_fma: bool,
60 pub has_reduction: bool,
62}
63
64impl Default for VectorizationInfo {
65 fn default() -> Self {
66 Self {
67 vectorizable: false,
68 reason: Some("not analyzed".to_string()),
69 recommended_width: 1,
70 access_patterns: Vec::new(),
71 has_fma: false,
72 has_reduction: false,
73 }
74 }
75}
76
77#[derive(Clone, Debug)]
79pub struct VectorizeConfig {
80 pub target: TargetArch,
82 pub forced_width: u8,
84 pub generate_remainder: bool,
86 pub enable_fma: bool,
88 pub min_trip_count: usize,
90}
91
92impl Default for VectorizeConfig {
93 fn default() -> Self {
94 Self {
95 target: TargetArch::default(),
96 forced_width: 0,
97 generate_remainder: true,
98 enable_fma: true,
99 min_trip_count: 4,
100 }
101 }
102}
103
104pub struct VectorizePass {
106 config: VectorizeConfig,
107 analysis: FxHashMap<LoopId, VectorizationInfo>,
109}
110
111impl VectorizePass {
112 pub fn new(config: VectorizeConfig) -> Self {
114 Self {
115 config,
116 analysis: FxHashMap::default(),
117 }
118 }
119
120 pub fn analyze(&mut self, ir: &LoopIR) -> FxHashMap<LoopId, VectorizationInfo> {
122 self.analysis.clear();
123
124 for stmt in &ir.body.stmts {
125 self.analyze_stmt(stmt, &ir.loop_info);
126 }
127
128 self.analysis.clone()
129 }
130
131 fn analyze_stmt(&mut self, stmt: &Stmt, loop_info: &[LoopMetadata]) {
133 if let Stmt::Loop(lp) = stmt {
134 let info = self.analyze_loop(lp, loop_info);
135 self.analysis.insert(lp.id, info);
136
137 for inner_stmt in &lp.body.stmts {
139 self.analyze_stmt(inner_stmt, loop_info);
140 }
141 }
142 }
143
144 fn analyze_loop(&self, lp: &Loop, loop_info: &[LoopMetadata]) -> VectorizationInfo {
146 let mut info = VectorizationInfo::default();
147
148 if !lp.attrs.contains(LoopAttrs::VECTORIZE) {
150 info.reason = Some("loop not marked VECTORIZE".to_string());
151 return info;
152 }
153
154 let metadata = loop_info.iter().find(|m| m.id == lp.id);
156 let trip_count = metadata.map(|m| &m.trip_count);
157
158 match trip_count {
159 Some(TripCount::Static(n)) if *n < self.config.min_trip_count => {
160 info.reason = Some(format!(
161 "trip count {} below threshold {}",
162 n, self.config.min_trip_count
163 ));
164 return info;
165 }
166 Some(TripCount::Dynamic) => {
167 }
170 _ => {}
171 }
172
173 let (patterns, has_fma, has_reduction) = self.analyze_loop_body(&lp.body);
175 info.access_patterns = patterns.clone();
176 info.has_fma = has_fma;
177 info.has_reduction = has_reduction;
178
179 let all_sequential = patterns
181 .iter()
182 .all(|p| matches!(p, AccessPattern::Sequential | AccessPattern::Broadcast));
183
184 if !all_sequential {
185 info.reason = Some("non-sequential access pattern".to_string());
186 return info;
187 }
188
189 let elem_type = self.infer_element_type(&lp.body);
191 let width = if self.config.forced_width > 0 {
192 self.config.forced_width
193 } else {
194 LoopType::natural_vector_width(elem_type, self.config.target)
195 };
196
197 info.vectorizable = width > 1;
198 info.recommended_width = width;
199 info.reason = None;
200
201 info
202 }
203
204 fn analyze_loop_body(&self, body: &Body) -> (Vec<AccessPattern>, bool, bool) {
206 let mut patterns = Vec::new();
207 let mut has_fma = false;
208 let mut has_reduction = false;
209
210 for stmt in &body.stmts {
211 match stmt {
212 Stmt::Assign(_, op) => {
213 if let Op::Load(mem_ref) = op {
215 patterns.push(mem_ref.access.clone());
216 }
217
218 if self.config.enable_fma {
220 has_fma |= self.is_fma_opportunity(op);
221 }
222
223 if let Op::VecReduce(_, _) = op {
225 has_reduction = true;
226 }
227 }
228 Stmt::Store(mem_ref, _) => {
229 patterns.push(mem_ref.access.clone());
230 }
231 Stmt::Loop(inner)
232 if inner.attrs.contains(LoopAttrs::REDUCTION) => {
234 has_reduction = true;
235 }
236 _ => {}
237 }
238 }
239
240 (patterns, has_fma, has_reduction)
241 }
242
243 fn is_fma_opportunity(&self, op: &Op) -> bool {
245 match op {
247 Op::Binary(BinOp::Add, _, _) => {
248 false
251 }
252 _ => false,
253 }
254 }
255
256 fn infer_element_type(&self, body: &Body) -> ScalarType {
258 for stmt in &body.stmts {
259 if let Stmt::Assign(_, Op::Load(mem_ref)) = stmt {
260 if let LoopType::Scalar(s) = &mem_ref.elem_ty {
261 return *s;
262 }
263 }
264 }
265 ScalarType::Float(32) }
267
268 pub fn vectorize(&self, ir: &mut LoopIR) -> Result<VectorizeReport, VectorizeError> {
270 let mut report = VectorizeReport::default();
271
272 for stmt in &mut ir.body.stmts {
273 self.vectorize_stmt(stmt, &mut ir.loop_info, &mut report)?;
274 }
275
276 Ok(report)
277 }
278
279 fn vectorize_stmt(
281 &self,
282 stmt: &mut Stmt,
283 loop_info: &mut [LoopMetadata],
284 report: &mut VectorizeReport,
285 ) -> Result<(), VectorizeError> {
286 if let Stmt::Loop(lp) = stmt {
287 if let Some(info) = self.analysis.get(&lp.id) {
288 if info.vectorizable {
289 self.vectorize_loop(lp, info, loop_info, report)?;
290 }
291 }
292
293 for inner_stmt in &mut lp.body.stmts {
295 self.vectorize_stmt(inner_stmt, loop_info, report)?;
296 }
297 }
298 Ok(())
299 }
300
301 fn vectorize_loop(
303 &self,
304 lp: &mut Loop,
305 info: &VectorizationInfo,
306 loop_info: &mut [LoopMetadata],
307 report: &mut VectorizeReport,
308 ) -> Result<(), VectorizeError> {
309 let width = info.recommended_width;
310
311 lp.step = Value::i64(width as i64);
313
314 if let Some(meta) = loop_info.iter_mut().find(|m| m.id == lp.id) {
316 meta.vector_width = Some(width);
317 }
318
319 self.vectorize_body(&mut lp.body, width)?;
321
322 report.vectorized_loops.push(VectorizedLoopInfo {
324 loop_id: lp.id,
325 vector_width: width,
326 has_fma: info.has_fma,
327 has_reduction: info.has_reduction,
328 });
329
330 Ok(())
331 }
332
333 fn vectorize_body(&self, body: &mut Body, width: u8) -> Result<(), VectorizeError> {
335 for stmt in &mut body.stmts {
336 if let Stmt::Assign(_, op) = stmt {
337 *op = self.vectorize_op(op, width)?;
338 }
339 }
340 Ok(())
341 }
342
343 fn vectorize_op(&self, op: &Op, width: u8) -> Result<Op, VectorizeError> {
345 match op {
346 Op::Load(mem_ref) => {
347 let mut vec_ref = mem_ref.clone();
349 if let LoopType::Scalar(s) = &mem_ref.elem_ty {
350 vec_ref.elem_ty = LoopType::Vector(*s, width);
351 }
352 Ok(Op::Load(vec_ref))
353 }
354
355 Op::Binary(bin_op, a, b) => {
356 let vec_a = self.vectorize_value(a, width);
358 let vec_b = self.vectorize_value(b, width);
359 Ok(Op::Binary(*bin_op, vec_a, vec_b))
360 }
361
362 Op::Unary(un_op, a) => {
363 let vec_a = self.vectorize_value(a, width);
364 Ok(Op::Unary(*un_op, vec_a))
365 }
366
367 Op::Fma(a, b, c) => {
369 let vec_a = self.vectorize_value(a, width);
370 let vec_b = self.vectorize_value(b, width);
371 let vec_c = self.vectorize_value(c, width);
372 Ok(Op::Fma(vec_a, vec_b, vec_c))
373 }
374
375 _ => Ok(op.clone()),
377 }
378 }
379
380 fn vectorize_value(&self, val: &Value, width: u8) -> Value {
382 match val {
383 Value::Var(id, LoopType::Scalar(s)) => Value::Var(*id, LoopType::Vector(*s, width)),
384 Value::FloatConst(f, s) => {
385 Value::FloatConst(*f, *s)
387 }
388 Value::IntConst(i, s) => Value::IntConst(*i, *s),
389 _ => val.clone(),
390 }
391 }
392}
393
394#[derive(Clone, Debug, Default)]
396pub struct VectorizeReport {
397 pub vectorized_loops: Vec<VectorizedLoopInfo>,
399 pub failed_loops: Vec<(LoopId, String)>,
401}
402
403impl VectorizeReport {
404 pub fn any_vectorized(&self) -> bool {
406 !self.vectorized_loops.is_empty()
407 }
408
409 pub fn count(&self) -> usize {
411 self.vectorized_loops.len()
412 }
413}
414
415impl std::fmt::Display for VectorizeReport {
416 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417 writeln!(f, "Vectorization Report")?;
418 writeln!(f, "====================")?;
419 writeln!(f, "Vectorized loops: {}", self.vectorized_loops.len())?;
420
421 for info in &self.vectorized_loops {
422 writeln!(
423 f,
424 " Loop {:?}: width={}, fma={}, reduction={}",
425 info.loop_id, info.vector_width, info.has_fma, info.has_reduction
426 )?;
427 }
428
429 if !self.failed_loops.is_empty() {
430 writeln!(f, "\nFailed loops: {}", self.failed_loops.len())?;
431 for (id, reason) in &self.failed_loops {
432 writeln!(f, " Loop {:?}: {}", id, reason)?;
433 }
434 }
435
436 Ok(())
437 }
438}
439
440#[derive(Clone, Debug)]
442pub struct VectorizedLoopInfo {
443 pub loop_id: LoopId,
445 pub vector_width: u8,
447 pub has_fma: bool,
449 pub has_reduction: bool,
451}
452
/// SIMD operations that can be mapped to target-specific intrinsic names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdIntrinsic {
    // Arithmetic.
    Add,
    Sub,
    Mul,
    Div,

    // Fused multiply-add family.
    Fmadd,
    Fmsub,
    Fnmadd,

    // Horizontal operations.
    Hadd,
    HorizontalSum,

    // Minimum / maximum.
    Min,
    Max,

    // Comparisons.
    CmpEq,
    CmpLt,
    CmpLe,

    // Lane manipulation.
    Broadcast,
    Extract,
    Insert,
    Shuffle,

    // Memory access.
    LoadAligned,
    LoadUnaligned,
    StoreAligned,
    StoreUnaligned,
}
520
521impl SimdIntrinsic {
522 pub fn x86_name(&self, ty: ScalarType, width: u8) -> &'static str {
524 match (self, ty, width) {
525 (Self::Add, ScalarType::Float(32), 4) => "_mm_add_ps",
527 (Self::Sub, ScalarType::Float(32), 4) => "_mm_sub_ps",
528 (Self::Mul, ScalarType::Float(32), 4) => "_mm_mul_ps",
529 (Self::Div, ScalarType::Float(32), 4) => "_mm_div_ps",
530 (Self::Fmadd, ScalarType::Float(32), 4) => "_mm_fmadd_ps",
531 (Self::Min, ScalarType::Float(32), 4) => "_mm_min_ps",
532 (Self::Max, ScalarType::Float(32), 4) => "_mm_max_ps",
533 (Self::LoadAligned, ScalarType::Float(32), 4) => "_mm_load_ps",
534 (Self::StoreAligned, ScalarType::Float(32), 4) => "_mm_store_ps",
535
536 (Self::Add, ScalarType::Float(32), 8) => "_mm256_add_ps",
538 (Self::Sub, ScalarType::Float(32), 8) => "_mm256_sub_ps",
539 (Self::Mul, ScalarType::Float(32), 8) => "_mm256_mul_ps",
540 (Self::Div, ScalarType::Float(32), 8) => "_mm256_div_ps",
541 (Self::Fmadd, ScalarType::Float(32), 8) => "_mm256_fmadd_ps",
542 (Self::Min, ScalarType::Float(32), 8) => "_mm256_min_ps",
543 (Self::Max, ScalarType::Float(32), 8) => "_mm256_max_ps",
544 (Self::LoadAligned, ScalarType::Float(32), 8) => "_mm256_load_ps",
545 (Self::StoreAligned, ScalarType::Float(32), 8) => "_mm256_store_ps",
546 (Self::Hadd, ScalarType::Float(32), 8) => "_mm256_hadd_ps",
547
548 (Self::Add, ScalarType::Float(64), 2) => "_mm_add_pd",
550 (Self::Sub, ScalarType::Float(64), 2) => "_mm_sub_pd",
551 (Self::Mul, ScalarType::Float(64), 2) => "_mm_mul_pd",
552 (Self::Fmadd, ScalarType::Float(64), 2) => "_mm_fmadd_pd",
553
554 (Self::Add, ScalarType::Float(64), 4) => "_mm256_add_pd",
556 (Self::Sub, ScalarType::Float(64), 4) => "_mm256_sub_pd",
557 (Self::Mul, ScalarType::Float(64), 4) => "_mm256_mul_pd",
558 (Self::Fmadd, ScalarType::Float(64), 4) => "_mm256_fmadd_pd",
559
560 _ => "unknown_intrinsic",
561 }
562 }
563
564 pub fn arm_name(&self, ty: ScalarType, width: u8) -> &'static str {
566 match (self, ty, width) {
567 (Self::Add, ScalarType::Float(32), 4) => "vaddq_f32",
569 (Self::Sub, ScalarType::Float(32), 4) => "vsubq_f32",
570 (Self::Mul, ScalarType::Float(32), 4) => "vmulq_f32",
571 (Self::Fmadd, ScalarType::Float(32), 4) => "vfmaq_f32",
572 (Self::Min, ScalarType::Float(32), 4) => "vminq_f32",
573 (Self::Max, ScalarType::Float(32), 4) => "vmaxq_f32",
574 (Self::LoadAligned, ScalarType::Float(32), 4) => "vld1q_f32",
575 (Self::StoreAligned, ScalarType::Float(32), 4) => "vst1q_f32",
576
577 (Self::Add, ScalarType::Float(64), 2) => "vaddq_f64",
579 (Self::Sub, ScalarType::Float(64), 2) => "vsubq_f64",
580 (Self::Mul, ScalarType::Float(64), 2) => "vmulq_f64",
581 (Self::Fmadd, ScalarType::Float(64), 2) => "vfmaq_f64",
582
583 _ => "unknown_intrinsic",
584 }
585 }
586}
587
588#[cfg(test)]
589mod tests {
590 use super::*;
591 use crate::{MemRef, Param, ValueId};
592 use bhc_index::Idx;
593 use bhc_intern::Symbol;
594 use bhc_tensor_ir::BufferId;
595
596 fn make_vectorizable_loop(trip_count: usize) -> (LoopIR, LoopId) {
597 let loop_id = LoopId::new(0);
598 let loop_var = ValueId::new(0);
599
600 let mem_ref = MemRef {
601 buffer: BufferId::new(0),
602 index: Value::Var(loop_var, LoopType::Scalar(ScalarType::I64)),
603 elem_ty: LoopType::Scalar(ScalarType::F32),
604 access: AccessPattern::Sequential,
605 };
606
607 let mut body = Body::new();
608 let load_result = ValueId::new(1);
609 body.push(Stmt::Assign(load_result, Op::Load(mem_ref.clone())));
610
611 let mul_result = ValueId::new(2);
612 body.push(Stmt::Assign(
613 mul_result,
614 Op::Binary(
615 BinOp::Mul,
616 Value::Var(load_result, LoopType::Scalar(ScalarType::F32)),
617 Value::float(2.0, 32),
618 ),
619 ));
620
621 body.push(Stmt::Store(
622 mem_ref,
623 Value::Var(mul_result, LoopType::Scalar(ScalarType::F32)),
624 ));
625
626 let lp = Loop {
627 id: loop_id,
628 var: loop_var,
629 lower: Value::i64(0),
630 upper: Value::i64(trip_count as i64),
631 step: Value::i64(1),
632 body,
633 attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT,
634 };
635
636 let mut outer_body = Body::new();
637 outer_body.push(Stmt::Loop(lp));
638
639 let ir = LoopIR {
640 name: Symbol::intern("test_kernel"),
641 params: vec![Param {
642 name: Symbol::intern("data"),
643 ty: LoopType::Ptr(Box::new(LoopType::Scalar(ScalarType::F32))),
644 is_ptr: true,
645 }],
646 return_ty: LoopType::Void,
647 body: outer_body,
648 allocs: vec![],
649 loop_info: vec![LoopMetadata {
650 id: loop_id,
651 trip_count: TripCount::Static(trip_count),
652 vector_width: None,
653 parallel_chunk: None,
654 unroll_factor: None,
655 dependencies: Vec::new(),
656 }],
657 };
658
659 (ir, loop_id)
660 }
661
662 #[test]
663 fn test_vectorization_analysis() {
664 let (ir, loop_id) = make_vectorizable_loop(1024);
665
666 let mut pass = VectorizePass::new(VectorizeConfig::default());
667 let analysis = pass.analyze(&ir);
668
669 let info = analysis.get(&loop_id).expect("loop should be analyzed");
670 assert!(info.vectorizable, "loop should be vectorizable");
671 assert!(
672 info.recommended_width > 1,
673 "should recommend vector width > 1"
674 );
675 }
676
677 #[test]
678 fn test_vectorization_below_threshold() {
679 let (ir, loop_id) = make_vectorizable_loop(2); let mut pass = VectorizePass::new(VectorizeConfig::default());
682 let analysis = pass.analyze(&ir);
683
684 let info = analysis.get(&loop_id).expect("loop should be analyzed");
685 assert!(!info.vectorizable, "small loop should not be vectorizable");
686 }
687
688 #[test]
689 fn test_vectorization_transform() {
690 let (mut ir, _loop_id) = make_vectorizable_loop(1024);
691
692 let mut pass = VectorizePass::new(VectorizeConfig::default());
693 pass.analyze(&ir);
694 let report = pass
695 .vectorize(&mut ir)
696 .expect("vectorization should succeed");
697
698 assert!(report.any_vectorized(), "should have vectorized loops");
699 assert_eq!(report.count(), 1, "should have vectorized 1 loop");
700 }
701
702 #[test]
703 fn test_simd_intrinsic_names() {
704 assert_eq!(
706 SimdIntrinsic::Add.x86_name(ScalarType::F32, 4),
707 "_mm_add_ps"
708 );
709 assert_eq!(
710 SimdIntrinsic::Fmadd.x86_name(ScalarType::F32, 4),
711 "_mm_fmadd_ps"
712 );
713
714 assert_eq!(
716 SimdIntrinsic::Add.x86_name(ScalarType::F32, 8),
717 "_mm256_add_ps"
718 );
719 assert_eq!(
720 SimdIntrinsic::Hadd.x86_name(ScalarType::F32, 8),
721 "_mm256_hadd_ps"
722 );
723
724 assert_eq!(SimdIntrinsic::Add.arm_name(ScalarType::F32, 4), "vaddq_f32");
726 assert_eq!(
727 SimdIntrinsic::Fmadd.arm_name(ScalarType::F32, 4),
728 "vfmaq_f32"
729 );
730 }
731
732 #[test]
733 fn test_target_vector_widths() {
734 assert_eq!(
736 LoopType::natural_vector_width(ScalarType::F32, TargetArch::X86_64Avx2),
737 8
738 );
739
740 assert_eq!(
742 LoopType::natural_vector_width(ScalarType::F32, TargetArch::X86_64Sse2),
743 4
744 );
745
746 assert_eq!(
748 LoopType::natural_vector_width(ScalarType::F32, TargetArch::Aarch64Neon),
749 4
750 );
751
752 assert_eq!(
754 LoopType::natural_vector_width(ScalarType::F64, TargetArch::X86_64Avx2),
755 4
756 );
757 }
758
759 #[test]
760 fn test_vectorize_report_display() {
761 let report = VectorizeReport {
762 vectorized_loops: vec![VectorizedLoopInfo {
763 loop_id: LoopId::new(0),
764 vector_width: 8,
765 has_fma: true,
766 has_reduction: false,
767 }],
768 failed_loops: vec![],
769 };
770
771 let output = format!("{}", report);
772 assert!(output.contains("Vectorized loops: 1"));
773 assert!(output.contains("width=8"));
774 assert!(output.contains("fma=true"));
775 }
776}